mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 23:42:13 +08:00
Compare commits
351 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b | ||
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
4f2a97073e | ||
|
|
7407e3b45d | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
e6d59216d4 | ||
|
|
4e8615f276 | ||
|
|
91e4d95660 | ||
|
|
45456fa24c | ||
|
|
4588258d80 | ||
|
|
c12e48f966 | ||
|
|
ec8f50a658 | ||
|
|
99c9191784 | ||
|
|
6bb02d141f | ||
|
|
07bb2a5f3f | ||
|
|
417861a48e | ||
|
|
b7e878de64 | ||
|
|
05edb5514b | ||
|
|
e90ec847b6 | ||
|
|
6344fa2a86 | ||
|
|
7e288acc90 | ||
|
|
29b0e4a8a5 | ||
|
|
27ff222cfb | ||
|
|
11f7b83522 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
1ee984478f | ||
|
|
fd693dc526 | ||
|
|
e73531ce9b | ||
|
|
53ad1645cf | ||
|
|
ecea13757b | ||
|
|
af9c4a7dd0 | ||
|
|
80d8d6c3bc | ||
|
|
d648811233 | ||
|
|
34695acb85 | ||
|
|
a63de12182 | ||
|
|
f16910d616 | ||
|
|
64b3f3cec1 | ||
|
|
6a685727d0 | ||
|
|
32d25f76fc | ||
|
|
69cafe8674 | ||
|
|
18ba8d9166 | ||
|
|
e97fd7e81c | ||
|
|
cdb64b0d33 | ||
|
|
8d4d3b03bb | ||
|
|
addefe79e1 | ||
|
|
b764d3b8f6 | ||
|
|
611fd884bd | ||
|
|
826090e099 | ||
|
|
7399de6ecc | ||
|
|
25cb5e7505 | ||
|
|
5c13ec3121 | ||
|
|
d8aff3a7e3 | ||
|
|
f44927b9f8 | ||
|
|
c0110cb5af | ||
|
|
1f8e1142a0 | ||
|
|
1e51de88d6 | ||
|
|
30995b5397 | ||
|
|
eb60f67054 | ||
|
|
78193ceec1 | ||
|
|
f0e08e7687 | ||
|
|
10b8259259 | ||
|
|
6826149a8f | ||
|
|
eb0b77bf4d | ||
|
|
9d81467937 | ||
|
|
fd8ccaf01a | ||
|
|
c9debc50b1 | ||
|
|
2b30e3b6d7 | ||
|
|
6e90ec6111 | ||
|
|
8dd38f4775 | ||
|
|
fbd73f248f | ||
|
|
3fcefe6c32 | ||
|
|
f740d2c291 | ||
|
|
bf6585a40f | ||
|
|
8c2dd7b3f0 | ||
|
|
4167c437a8 | ||
|
|
0ddaef3c9a | ||
|
|
2fc6aaf936 | ||
|
|
1c0519f1c7 | ||
|
|
6bbe7800be | ||
|
|
2694149489 | ||
|
|
a17ac50118 | ||
|
|
656a77d585 | ||
|
|
7455476c60 | ||
|
|
36cda57c81 | ||
|
|
9f1f203b84 | ||
|
|
b41a8ca93f | ||
|
|
e3cf0c0e10 | ||
|
|
de18bce9aa | ||
|
|
3cc407bc0e | ||
|
|
00a0a12138 | ||
|
|
b08767a4f9 | ||
|
|
ac6bde7a98 | ||
|
|
d2d41d68dd | ||
|
|
944b7f7617 | ||
|
|
53825eb073 | ||
|
|
1a7f49513f | ||
|
|
885a2ce7ef | ||
|
|
14ba80a0af | ||
|
|
5fa22fdf82 | ||
|
|
bcaae2eb91 | ||
|
|
767a41e263 | ||
|
|
252d6c5301 | ||
|
|
7a4e65ad4b | ||
|
|
a582aa89a9 | ||
|
|
acefa1da12 | ||
|
|
a88698f3fc | ||
|
|
ebc6755b33 | ||
|
|
c8eff34388 | ||
|
|
f19b03825b | ||
|
|
25178cdbe1 | ||
|
|
a461538d58 | ||
|
|
b43ee62947 | ||
|
|
ebe6f418f3 | ||
|
|
391e79f8ee | ||
|
|
c7fcb7a84b | ||
|
|
87f4ed591e | ||
|
|
440d2e28ed | ||
|
|
6cb8980404 | ||
|
|
fe752bbd35 | ||
|
|
c74d451fa2 | ||
|
|
12d743fb35 | ||
|
|
6acb9f7910 | ||
|
|
eb6f5c6927 | ||
|
|
7ccb4c8ea3 | ||
|
|
4ce986d47d | ||
|
|
91ef085d7d | ||
|
|
97aaa24733 | ||
|
|
faf6441633 | ||
|
|
00c151b463 | ||
|
|
106b20cdbf | ||
|
|
c069b3b1e8 | ||
|
|
a2ae9f1f27 | ||
|
|
4cd6d86426 | ||
|
|
fa72f1947a | ||
|
|
9ee7d3935d | ||
|
|
1071fe0ac7 | ||
|
|
0be003377f | ||
|
|
ca3f497b56 | ||
|
|
034b84b707 | ||
|
|
1624523c4e | ||
|
|
313afe14ce | ||
|
|
01180b316f | ||
|
|
ee7d061001 | ||
|
|
60c5949a74 | ||
|
|
2ebbd4c94d | ||
|
|
785115c62b | ||
|
|
e643fc382c | ||
|
|
34aad82ac3 | ||
|
|
0c29468f90 | ||
|
|
9301dae63e | ||
|
|
2475d4a205 | ||
|
|
be75fc3474 | ||
|
|
785e049af3 | ||
|
|
be4e49e6d7 | ||
|
|
1307d604e7 | ||
|
|
45d57018eb | ||
|
|
03bf348530 | ||
|
|
cab60ef735 | ||
|
|
a3791104f9 | ||
|
|
2b3e40bb2a | ||
|
|
0c1dcad429 | ||
|
|
101ef0cf62 | ||
|
|
0debe0a80c | ||
|
|
d22e62ac8a | ||
|
|
1ee17383f8 | ||
|
|
b59c79c458 | ||
|
|
bcb6444f89 | ||
|
|
c2b14693b4 | ||
|
|
92d35409de | ||
|
|
351a08f813 | ||
|
|
a58dc787a9 | ||
|
|
7079edc2d0 | ||
|
|
da89583ccc | ||
|
|
a42a1f08e9 | ||
|
|
ebd5253e22 | ||
|
|
6411645ffc | ||
|
|
c0c322ba16 | ||
|
|
d35c5cd491 | ||
|
|
7a353028e7 | ||
|
|
2d8d3b7857 | ||
|
|
4190293b07 | ||
|
|
421b4c0aff | ||
|
|
cd69a7cb85 | ||
|
|
0c9ba9e86c | ||
|
|
1b4d2a41c9 | ||
|
|
0787d2b47a | ||
|
|
97bf1d85ab | ||
|
|
207a493fab | ||
|
|
1f3f9e131e | ||
|
|
4ddedfaaf9 | ||
|
|
3ebebef95f | ||
|
|
9f7ad47598 | ||
|
|
3c83cd8be2 | ||
|
|
963b3b768c | ||
|
|
f6709fb5d6 | ||
|
|
921599948b | ||
|
|
5df3cafa99 | ||
|
|
1a2143c1fe | ||
|
|
dd25281305 | ||
|
|
49d0301dde | ||
|
|
d90e56eb45 | ||
|
|
838ada8864 | ||
|
|
65a106792a | ||
|
|
ee4bfcbb81 | ||
|
|
a087f089b8 | ||
|
|
afbe8bf001 | ||
|
|
2a3ef0be06 | ||
|
|
3403909354 | ||
|
|
005d0c5f53 | ||
|
|
8aaaeb29cc | ||
|
|
230f8abd04 | ||
|
|
a18bbb5f2f | ||
|
|
60fce4f1dc | ||
|
|
9af65efcdb | ||
|
|
bc194a7d8c | ||
|
|
c28f691f32 | ||
|
|
ff1f114989 | ||
|
|
cac230206d | ||
|
|
79ae15d5e8 | ||
|
|
0cce0a8877 | ||
|
|
225fd035ae | ||
|
|
fb7d1346b5 | ||
|
|
491a744481 | ||
|
|
f366026435 | ||
|
|
1a0d4ed668 | ||
|
|
63a8c76946 | ||
|
|
f355a68bc9 | ||
|
|
c87e6526c1 | ||
|
|
af3a5076d6 | ||
|
|
18f2e21414 | ||
|
|
8a8cdeebb4 | ||
|
|
12b33f4ea4 | ||
|
|
01b3a09d7d | ||
|
|
0d6c1c7790 | ||
|
|
95e366b6c6 | ||
|
|
77701143bf | ||
|
|
02dea7b09b | ||
|
|
c26f93c4a0 | ||
|
|
c826ac28ef | ||
|
|
1893b0eb30 | ||
|
|
05527b13db | ||
|
|
ae5d9c8bfc | ||
|
|
9117c2a4ec | ||
|
|
bab4bb9904 | ||
|
|
33bae6f49b | ||
|
|
32d619a56b | ||
|
|
642432cf2a | ||
|
|
61e9598b08 | ||
|
|
d4e34c7514 | ||
|
|
bfe7a5e452 | ||
|
|
77d916ffec | ||
|
|
831abf7977 | ||
|
|
817a491087 | ||
|
|
9a8dacc514 | ||
|
|
8adf80d98b | ||
|
|
62686a6213 | ||
|
|
3a089242f8 | ||
|
|
9d70c38504 | ||
|
|
aeb464f3ca | ||
|
|
7076717b20 | ||
|
|
c0a4fcea0a | ||
|
|
aa2b195c86 | ||
|
|
1d0872e7ca | ||
|
|
33988637b5 | ||
|
|
d4f6ad7225 | ||
|
|
078fefed03 | ||
|
|
5b10af85b4 | ||
|
|
4caf95e5dd | ||
|
|
8e1bcf53bb | ||
|
|
064f9be7e4 | ||
|
|
adcfb44cb7 | ||
|
|
3d79773ba2 | ||
|
|
6aa8cbbf20 | ||
|
|
742e73c9c2 | ||
|
|
f8de2bdedc | ||
|
|
59879b7fa7 | ||
|
|
27abae21b8 | ||
|
|
0819c8a51a | ||
|
|
9dcd3cd491 | ||
|
|
49767cccd2 | ||
|
|
29fb447daa | ||
|
|
f6fe5b552d | ||
|
|
bd0801a887 | ||
|
|
05b1c66aa8 | ||
|
|
80ae592c23 | ||
|
|
ba6de4c4d4 | ||
|
|
46ea9170cb | ||
|
|
7d318aeefa | ||
|
|
0aa3cf677a | ||
|
|
72961c5858 | ||
|
|
a05711a37a | ||
|
|
efc9e1d673 |
6
.github/workflows/backend-ci.yml
vendored
6
.github/workflows/backend-ci.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: Unit tests
|
||||
working-directory: backend
|
||||
run: make test-unit
|
||||
@@ -38,10 +38,10 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.7
|
||||
version: v2.9
|
||||
args: --timeout=30m
|
||||
working-directory: backend
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -115,7 +115,7 @@ jobs:
|
||||
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
|
||||
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: Run govulncheck
|
||||
working-directory: backend
|
||||
run: |
|
||||
|
||||
105
AGENTS.md
105
AGENTS.md
@@ -1,105 +0,0 @@
|
||||
# Repository Guidelines
|
||||
|
||||
## Project Structure & Module Organization
|
||||
- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output.
|
||||
- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`.
|
||||
- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`).
|
||||
- `openspec/`: Spec-driven change docs (`changes/<id>/{proposal,design,tasks}.md`).
|
||||
- `tools/`: Utility scripts (security/perf checks).
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
```bash
|
||||
make build # Build backend + frontend
|
||||
make test # Backend tests + frontend lint/typecheck
|
||||
cd backend && make build # Build backend binary
|
||||
cd backend && make test-unit # Go unit tests
|
||||
cd backend && make test-integration # Go integration tests
|
||||
cd backend && make test # go test ./... + golangci-lint
|
||||
cd frontend && pnpm install --frozen-lockfile
|
||||
cd frontend && pnpm dev # Vite dev server
|
||||
cd frontend && pnpm build # Type-check + production build
|
||||
cd frontend && pnpm test:run # Vitest run
|
||||
cd frontend && pnpm test:coverage # Vitest + coverage report
|
||||
python3 tools/secret_scan.py # Secret scan
|
||||
```
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`).
|
||||
- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard).
|
||||
- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`.
|
||||
- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`.
|
||||
|
||||
## Go & Frontend Development Standards
|
||||
- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears.
|
||||
- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency.
|
||||
- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable.
|
||||
|
||||
### Go Performance Rules
|
||||
- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code.
|
||||
- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable.
|
||||
- Every external I/O path must use `context.Context` with explicit timeout/cancel.
|
||||
- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources.
|
||||
- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`).
|
||||
- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`.
|
||||
- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths.
|
||||
- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after).
|
||||
- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out.
|
||||
- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention.
|
||||
- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible.
|
||||
- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary.
|
||||
- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions.
|
||||
- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain.
|
||||
- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible.
|
||||
- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled.
|
||||
|
||||
### Data Access & Cache Rules
|
||||
- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR.
|
||||
- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints.
|
||||
- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths.
|
||||
- Keep transactions short; never perform external RPC/network calls inside DB transactions.
|
||||
- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`).
|
||||
- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks.
|
||||
- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd.
|
||||
- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths.
|
||||
|
||||
### Frontend Performance Rules
|
||||
- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components.
|
||||
- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs.
|
||||
- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it.
|
||||
- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed.
|
||||
- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows).
|
||||
- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking.
|
||||
- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review.
|
||||
- Avoid expensive inline computations in templates; move to cached `computed` selectors.
|
||||
- Keep state normalized; avoid duplicated derived state across multiple stores/components.
|
||||
- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import.
|
||||
- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`.
|
||||
- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks.
|
||||
- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes.
|
||||
|
||||
### Performance Budget & PR Evidence
|
||||
- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline.
|
||||
- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval.
|
||||
- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table).
|
||||
- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output).
|
||||
|
||||
### Quality Gate
|
||||
- Any changed code must include new or updated unit tests.
|
||||
- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules).
|
||||
- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description.
|
||||
|
||||
## Testing Guidelines
|
||||
- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant.
|
||||
- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`.
|
||||
- Enforce unit-test and coverage rules defined in `Quality Gate`.
|
||||
- Before opening a PR, run `make test` plus targeted tests for touched areas.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`.
|
||||
- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes.
|
||||
- For behavior/API changes, add or update `openspec/changes/...` artifacts.
|
||||
- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR.
|
||||
|
||||
## Security & Configuration Tips
|
||||
- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials.
|
||||
- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev.
|
||||
22
Dockerfile
22
Dockerfile
@@ -7,8 +7,9 @@
|
||||
# =============================================================================
|
||||
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.25.7-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
./cmd/server
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 3: Final Runtime Image
|
||||
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
@@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
@@ -5,7 +5,12 @@
|
||||
# It only packages the pre-built binary, no compilation needed.
|
||||
# =============================================================================
|
||||
|
||||
FROM alpine:3.19
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
@@ -16,8 +21,20 @@ RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||
# restore work in the runtime container without requiring Docker socket access.
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
38
README.md
38
README.md
@@ -39,6 +39,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||
- **Rate Limiting** - Configurable request and token rate limits
|
||||
- **Admin Dashboard** - Web interface for monitoring and management
|
||||
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||
|
||||
## Ecosystem
|
||||
|
||||
Community projects that extend or integrate with Sub2API:
|
||||
|
||||
| Project | Description | Features |
|
||||
|---------|-------------|----------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
|
||||
|
||||
## Tech Stack
|
||||
|
||||
@@ -150,14 +160,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# Start services
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# View logs
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**What the script does:**
|
||||
- Downloads `docker-compose.local.yml` and `.env.example`
|
||||
- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
|
||||
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
||||
- Creates `.env` file with auto-generated secrets
|
||||
- Creates data directories (uses local directories for easy backup/migration)
|
||||
@@ -522,6 +532,28 @@ sub2api/
|
||||
└── install.sh # One-click installation script
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
|
||||
> **Please read carefully before using this project:**
|
||||
>
|
||||
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
|
||||
>
|
||||
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
40
README_CN.md
40
README_CN.md
@@ -39,6 +39,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
- **并发控制** - 用户级和账号级并发限制
|
||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||
- **管理后台** - Web 界面进行监控和管理
|
||||
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||
|
||||
## 生态项目
|
||||
|
||||
围绕 Sub2API 的社区扩展与集成项目:
|
||||
|
||||
| 项目 | 说明 | 功能 |
|
||||
|------|------|------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
|
||||
|
||||
## 技术栈
|
||||
|
||||
@@ -137,8 +147,6 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
|
||||
|
||||
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
|
||||
|
||||
如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。
|
||||
|
||||
#### 前置条件
|
||||
|
||||
- Docker 20.10+
|
||||
@@ -156,14 +164,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# 启动服务
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# 查看日志
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**脚本功能:**
|
||||
- 下载 `docker-compose.local.yml` 和 `.env.example`
|
||||
- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example`
|
||||
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
||||
- 创建 `.env` 文件并填充自动生成的密钥
|
||||
- 创建数据目录(使用本地目录,便于备份和迁移)
|
||||
@@ -590,6 +598,28 @@ sub2api/
|
||||
└── install.sh # 一键安装脚本
|
||||
```
|
||||
|
||||
## 免责声明
|
||||
|
||||
> **使用本项目前请仔细阅读:**
|
||||
>
|
||||
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
|
||||
>
|
||||
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
@@ -93,20 +93,13 @@ linters:
|
||||
check-escaping-errors: true
|
||||
staticcheck:
|
||||
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
|
||||
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
|
||||
dot-import-whitelist:
|
||||
- fmt
|
||||
# https://staticcheck.dev/docs/configuration/options/#initialisms
|
||||
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
|
||||
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
|
||||
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
|
||||
# Default: ["200", "400", "404", "500"]
|
||||
http-status-code-whitelist: [ "200", "400", "404", "500" ]
|
||||
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
|
||||
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||
# Temporarily disable style checks to allow CI to pass
|
||||
# "all" enables every SA/ST/S/QF check; only list the ones to disable.
|
||||
checks:
|
||||
- all
|
||||
- -ST1000 # Package comment format
|
||||
@@ -114,489 +107,19 @@ linters:
|
||||
- -ST1020 # Comment on exported method format
|
||||
- -ST1021 # Comment on exported type format
|
||||
- -ST1022 # Comment on exported variable format
|
||||
# Invalid regular expression.
|
||||
# https://staticcheck.dev/docs/checks/#SA1000
|
||||
- SA1000
|
||||
# Invalid template.
|
||||
# https://staticcheck.dev/docs/checks/#SA1001
|
||||
- SA1001
|
||||
# Invalid format in 'time.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1002
|
||||
- SA1002
|
||||
# Unsupported argument to functions in 'encoding/binary'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1003
|
||||
- SA1003
|
||||
# Suspiciously small untyped constant in 'time.Sleep'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1004
|
||||
- SA1004
|
||||
# Invalid first argument to 'exec.Command'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1005
|
||||
- SA1005
|
||||
# 'Printf' with dynamic first argument and no further arguments.
|
||||
# https://staticcheck.dev/docs/checks/#SA1006
|
||||
- SA1006
|
||||
# Invalid URL in 'net/url.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1007
|
||||
- SA1007
|
||||
# Non-canonical key in 'http.Header' map.
|
||||
# https://staticcheck.dev/docs/checks/#SA1008
|
||||
- SA1008
|
||||
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
|
||||
# https://staticcheck.dev/docs/checks/#SA1010
|
||||
- SA1010
|
||||
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
|
||||
# https://staticcheck.dev/docs/checks/#SA1011
|
||||
- SA1011
|
||||
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA1012
|
||||
- SA1012
|
||||
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
|
||||
# https://staticcheck.dev/docs/checks/#SA1013
|
||||
- SA1013
|
||||
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1014
|
||||
- SA1014
|
||||
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1015
|
||||
- SA1015
|
||||
# Trapping a signal that cannot be trapped.
|
||||
# https://staticcheck.dev/docs/checks/#SA1016
|
||||
- SA1016
|
||||
# Channels used with 'os/signal.Notify' should be buffered.
|
||||
# https://staticcheck.dev/docs/checks/#SA1017
|
||||
- SA1017
|
||||
# 'strings.Replace' called with 'n == 0', which does nothing.
|
||||
# https://staticcheck.dev/docs/checks/#SA1018
|
||||
- SA1018
|
||||
# Using a deprecated function, variable, constant or field.
|
||||
# https://staticcheck.dev/docs/checks/#SA1019
|
||||
- SA1019
|
||||
# Using an invalid host:port pair with a 'net.Listen'-related function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1020
|
||||
- SA1020
|
||||
# Using 'bytes.Equal' to compare two 'net.IP'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1021
|
||||
- SA1021
|
||||
# Modifying the buffer in an 'io.Writer' implementation.
|
||||
# https://staticcheck.dev/docs/checks/#SA1023
|
||||
- SA1023
|
||||
# A string cutset contains duplicate characters.
|
||||
# https://staticcheck.dev/docs/checks/#SA1024
|
||||
- SA1024
|
||||
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
|
||||
# https://staticcheck.dev/docs/checks/#SA1025
|
||||
- SA1025
|
||||
# Cannot marshal channels or functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1026
|
||||
- SA1026
|
||||
# Atomic access to 64-bit variable must be 64-bit aligned.
|
||||
# https://staticcheck.dev/docs/checks/#SA1027
|
||||
- SA1027
|
||||
# 'sort.Slice' can only be used on slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA1028
|
||||
- SA1028
|
||||
# Inappropriate key in call to 'context.WithValue'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1029
|
||||
- SA1029
|
||||
# Invalid argument in call to a 'strconv' function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1030
|
||||
- SA1030
|
||||
# Overlapping byte slices passed to an encoder.
|
||||
# https://staticcheck.dev/docs/checks/#SA1031
|
||||
- SA1031
|
||||
# Wrong order of arguments to 'errors.Is'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1032
|
||||
- SA1032
|
||||
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
|
||||
# https://staticcheck.dev/docs/checks/#SA2000
|
||||
- SA2000
|
||||
# Empty critical section, did you mean to defer the unlock?.
|
||||
# https://staticcheck.dev/docs/checks/#SA2001
|
||||
- SA2001
|
||||
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
|
||||
# https://staticcheck.dev/docs/checks/#SA2002
|
||||
- SA2002
|
||||
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA2003
|
||||
- SA2003
|
||||
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
|
||||
# https://staticcheck.dev/docs/checks/#SA3000
|
||||
- SA3000
|
||||
# Assigning to 'b.N' in benchmarks distorts the results.
|
||||
# https://staticcheck.dev/docs/checks/#SA3001
|
||||
- SA3001
|
||||
# Binary operator has identical expressions on both sides.
|
||||
# https://staticcheck.dev/docs/checks/#SA4000
|
||||
- SA4000
|
||||
# '&*x' gets simplified to 'x', it does not copy 'x'.
|
||||
# https://staticcheck.dev/docs/checks/#SA4001
|
||||
- SA4001
|
||||
# Comparing unsigned values against negative values is pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4003
|
||||
- SA4003
|
||||
# The loop exits unconditionally after one iteration.
|
||||
# https://staticcheck.dev/docs/checks/#SA4004
|
||||
- SA4004
|
||||
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4005
|
||||
- SA4005
|
||||
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4006
|
||||
- SA4006
|
||||
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4008
|
||||
- SA4008
|
||||
# A function argument is overwritten before its first use.
|
||||
# https://staticcheck.dev/docs/checks/#SA4009
|
||||
- SA4009
|
||||
# The result of 'append' will never be observed anywhere.
|
||||
# https://staticcheck.dev/docs/checks/#SA4010
|
||||
- SA4010
|
||||
# Break statement with no effect. Did you mean to break out of an outer loop?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4011
|
||||
- SA4011
|
||||
# Comparing a value against NaN even though no value is equal to NaN.
|
||||
# https://staticcheck.dev/docs/checks/#SA4012
|
||||
- SA4012
|
||||
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
|
||||
# https://staticcheck.dev/docs/checks/#SA4013
|
||||
- SA4013
|
||||
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
|
||||
# https://staticcheck.dev/docs/checks/#SA4014
|
||||
- SA4014
|
||||
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4015
|
||||
- SA4015
|
||||
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4016
|
||||
- SA4016
|
||||
# Discarding the return values of a function without side effects, making the call pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4017
|
||||
- SA4017
|
||||
# Self-assignment of variables.
|
||||
# https://staticcheck.dev/docs/checks/#SA4018
|
||||
- SA4018
|
||||
# Multiple, identical build constraints in the same file.
|
||||
# https://staticcheck.dev/docs/checks/#SA4019
|
||||
- SA4019
|
||||
# Unreachable case clause in a type switch.
|
||||
# https://staticcheck.dev/docs/checks/#SA4020
|
||||
- SA4020
|
||||
# "x = append(y)" is equivalent to "x = y".
|
||||
# https://staticcheck.dev/docs/checks/#SA4021
|
||||
- SA4021
|
||||
# Comparing the address of a variable against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4022
|
||||
- SA4022
|
||||
# Impossible comparison of interface value with untyped nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4023
|
||||
- SA4023
|
||||
# Checking for impossible return value from a builtin function.
|
||||
# https://staticcheck.dev/docs/checks/#SA4024
|
||||
- SA4024
|
||||
# Integer division of literals that results in zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4025
|
||||
- SA4025
|
||||
# Go constants cannot express negative zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4026
|
||||
- SA4026
|
||||
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
|
||||
# https://staticcheck.dev/docs/checks/#SA4027
|
||||
- SA4027
|
||||
# 'x % 1' is always zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4028
|
||||
- SA4028
|
||||
# Ineffective attempt at sorting slice.
|
||||
# https://staticcheck.dev/docs/checks/#SA4029
|
||||
- SA4029
|
||||
# Ineffective attempt at generating random number.
|
||||
# https://staticcheck.dev/docs/checks/#SA4030
|
||||
- SA4030
|
||||
# Checking never-nil value against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4031
|
||||
- SA4031
|
||||
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
|
||||
# https://staticcheck.dev/docs/checks/#SA4032
|
||||
- SA4032
|
||||
# Assignment to nil map.
|
||||
# https://staticcheck.dev/docs/checks/#SA5000
|
||||
- SA5000
|
||||
# Deferring 'Close' before checking for a possible error.
|
||||
# https://staticcheck.dev/docs/checks/#SA5001
|
||||
- SA5001
|
||||
# The empty for loop ("for {}") spins and can block the scheduler.
|
||||
# https://staticcheck.dev/docs/checks/#SA5002
|
||||
- SA5002
|
||||
# Defers in infinite loops will never execute.
|
||||
# https://staticcheck.dev/docs/checks/#SA5003
|
||||
- SA5003
|
||||
# "for { select { ..." with an empty default branch spins.
|
||||
# https://staticcheck.dev/docs/checks/#SA5004
|
||||
- SA5004
|
||||
# The finalizer references the finalized object, preventing garbage collection.
|
||||
# https://staticcheck.dev/docs/checks/#SA5005
|
||||
- SA5005
|
||||
# Infinite recursive call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5007
|
||||
- SA5007
|
||||
# Invalid struct tag.
|
||||
# https://staticcheck.dev/docs/checks/#SA5008
|
||||
- SA5008
|
||||
# Invalid Printf call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5009
|
||||
- SA5009
|
||||
# Impossible type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#SA5010
|
||||
- SA5010
|
||||
# Possible nil pointer dereference.
|
||||
# https://staticcheck.dev/docs/checks/#SA5011
|
||||
- SA5011
|
||||
# Passing odd-sized slice to function expecting even size.
|
||||
# https://staticcheck.dev/docs/checks/#SA5012
|
||||
- SA5012
|
||||
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6000
|
||||
- SA6000
|
||||
# Missing an optimization opportunity when indexing maps by byte slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA6001
|
||||
- SA6001
|
||||
# Storing non-pointer values in 'sync.Pool' allocates memory.
|
||||
# https://staticcheck.dev/docs/checks/#SA6002
|
||||
- SA6002
|
||||
# Converting a string to a slice of runes before ranging over it.
|
||||
# https://staticcheck.dev/docs/checks/#SA6003
|
||||
- SA6003
|
||||
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6005
|
||||
- SA6005
|
||||
# Using io.WriteString to write '[]byte'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6006
|
||||
- SA6006
|
||||
# Defers in range loops may not run when you expect them to.
|
||||
# https://staticcheck.dev/docs/checks/#SA9001
|
||||
- SA9001
|
||||
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
|
||||
# https://staticcheck.dev/docs/checks/#SA9002
|
||||
- SA9002
|
||||
# Empty body in an if or else branch.
|
||||
# https://staticcheck.dev/docs/checks/#SA9003
|
||||
- SA9003
|
||||
# Only the first constant has an explicit type.
|
||||
# https://staticcheck.dev/docs/checks/#SA9004
|
||||
- SA9004
|
||||
# Trying to marshal a struct with no public fields nor custom marshaling.
|
||||
# https://staticcheck.dev/docs/checks/#SA9005
|
||||
- SA9005
|
||||
# Dubious bit shifting of a fixed size integer value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9006
|
||||
- SA9006
|
||||
# Deleting a directory that shouldn't be deleted.
|
||||
# https://staticcheck.dev/docs/checks/#SA9007
|
||||
- SA9007
|
||||
# 'else' branch of a type assertion is probably not reading the right value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9008
|
||||
- SA9008
|
||||
# Ineffectual Go compiler directive.
|
||||
# https://staticcheck.dev/docs/checks/#SA9009
|
||||
- SA9009
|
||||
# NOTE: ST1000, ST1001, ST1003, ST1020, ST1021, ST1022 are disabled above
|
||||
# Incorrectly formatted error string.
|
||||
# https://staticcheck.dev/docs/checks/#ST1005
|
||||
- ST1005
|
||||
# Poorly chosen receiver name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1006
|
||||
- ST1006
|
||||
# A function's error value should be its last return value.
|
||||
# https://staticcheck.dev/docs/checks/#ST1008
|
||||
- ST1008
|
||||
# Poorly chosen name for variable of type 'time.Duration'.
|
||||
# https://staticcheck.dev/docs/checks/#ST1011
|
||||
- ST1011
|
||||
# Poorly chosen name for error variable.
|
||||
# https://staticcheck.dev/docs/checks/#ST1012
|
||||
- ST1012
|
||||
# Should use constants for HTTP error codes, not magic numbers.
|
||||
# https://staticcheck.dev/docs/checks/#ST1013
|
||||
- ST1013
|
||||
# A switch's default case should be the first or last case.
|
||||
# https://staticcheck.dev/docs/checks/#ST1015
|
||||
- ST1015
|
||||
# Use consistent method receiver names.
|
||||
# https://staticcheck.dev/docs/checks/#ST1016
|
||||
- ST1016
|
||||
# Don't use Yoda conditions.
|
||||
# https://staticcheck.dev/docs/checks/#ST1017
|
||||
- ST1017
|
||||
# Avoid zero-width and control characters in string literals.
|
||||
# https://staticcheck.dev/docs/checks/#ST1018
|
||||
- ST1018
|
||||
# Importing the same package multiple times.
|
||||
# https://staticcheck.dev/docs/checks/#ST1019
|
||||
- ST1019
|
||||
# NOTE: ST1020, ST1021, ST1022 removed (disabled above)
|
||||
# Redundant type in variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#ST1023
|
||||
- ST1023
|
||||
# Use plain channel send or receive instead of single-case select.
|
||||
# https://staticcheck.dev/docs/checks/#S1000
|
||||
- S1000
|
||||
# Replace for loop with call to copy.
|
||||
# https://staticcheck.dev/docs/checks/#S1001
|
||||
- S1001
|
||||
# Omit comparison with boolean constant.
|
||||
# https://staticcheck.dev/docs/checks/#S1002
|
||||
- S1002
|
||||
# Replace call to 'strings.Index' with 'strings.Contains'.
|
||||
# https://staticcheck.dev/docs/checks/#S1003
|
||||
- S1003
|
||||
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
|
||||
# https://staticcheck.dev/docs/checks/#S1004
|
||||
- S1004
|
||||
# Drop unnecessary use of the blank identifier.
|
||||
# https://staticcheck.dev/docs/checks/#S1005
|
||||
- S1005
|
||||
# Use "for { ... }" for infinite loops.
|
||||
# https://staticcheck.dev/docs/checks/#S1006
|
||||
- S1006
|
||||
# Simplify regular expression by using raw string literal.
|
||||
# https://staticcheck.dev/docs/checks/#S1007
|
||||
- S1007
|
||||
# Simplify returning boolean expression.
|
||||
# https://staticcheck.dev/docs/checks/#S1008
|
||||
- S1008
|
||||
# Omit redundant nil check on slices, maps, and channels.
|
||||
# https://staticcheck.dev/docs/checks/#S1009
|
||||
- S1009
|
||||
# Omit default slice index.
|
||||
# https://staticcheck.dev/docs/checks/#S1010
|
||||
- S1010
|
||||
# Use a single 'append' to concatenate two slices.
|
||||
# https://staticcheck.dev/docs/checks/#S1011
|
||||
- S1011
|
||||
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1012
|
||||
- S1012
|
||||
# Use a type conversion instead of manually copying struct fields.
|
||||
# https://staticcheck.dev/docs/checks/#S1016
|
||||
- S1016
|
||||
# Replace manual trimming with 'strings.TrimPrefix'.
|
||||
# https://staticcheck.dev/docs/checks/#S1017
|
||||
- S1017
|
||||
# Use "copy" for sliding elements.
|
||||
# https://staticcheck.dev/docs/checks/#S1018
|
||||
- S1018
|
||||
# Simplify "make" call by omitting redundant arguments.
|
||||
# https://staticcheck.dev/docs/checks/#S1019
|
||||
- S1019
|
||||
# Omit redundant nil check in type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#S1020
|
||||
- S1020
|
||||
# Merge variable declaration and assignment.
|
||||
# https://staticcheck.dev/docs/checks/#S1021
|
||||
- S1021
|
||||
# Omit redundant control flow.
|
||||
# https://staticcheck.dev/docs/checks/#S1023
|
||||
- S1023
|
||||
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1024
|
||||
- S1024
|
||||
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
|
||||
# https://staticcheck.dev/docs/checks/#S1025
|
||||
- S1025
|
||||
# Simplify error construction with 'fmt.Errorf'.
|
||||
# https://staticcheck.dev/docs/checks/#S1028
|
||||
- S1028
|
||||
# Range over the string directly.
|
||||
# https://staticcheck.dev/docs/checks/#S1029
|
||||
- S1029
|
||||
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
|
||||
# https://staticcheck.dev/docs/checks/#S1030
|
||||
- S1030
|
||||
# Omit redundant nil check around loop.
|
||||
# https://staticcheck.dev/docs/checks/#S1031
|
||||
- S1031
|
||||
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1032
|
||||
- S1032
|
||||
# Unnecessary guard around call to "delete".
|
||||
# https://staticcheck.dev/docs/checks/#S1033
|
||||
- S1033
|
||||
# Use result of type assertion to simplify cases.
|
||||
# https://staticcheck.dev/docs/checks/#S1034
|
||||
- S1034
|
||||
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
|
||||
# https://staticcheck.dev/docs/checks/#S1035
|
||||
- S1035
|
||||
# Unnecessary guard around map access.
|
||||
# https://staticcheck.dev/docs/checks/#S1036
|
||||
- S1036
|
||||
# Elaborate way of sleeping.
|
||||
# https://staticcheck.dev/docs/checks/#S1037
|
||||
- S1037
|
||||
# Unnecessarily complex way of printing formatted string.
|
||||
# https://staticcheck.dev/docs/checks/#S1038
|
||||
- S1038
|
||||
# Unnecessary use of 'fmt.Sprint'.
|
||||
# https://staticcheck.dev/docs/checks/#S1039
|
||||
- S1039
|
||||
# Type assertion to current type.
|
||||
# https://staticcheck.dev/docs/checks/#S1040
|
||||
- S1040
|
||||
# Apply De Morgan's law.
|
||||
# https://staticcheck.dev/docs/checks/#QF1001
|
||||
- QF1001
|
||||
# Convert untagged switch to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1002
|
||||
- QF1002
|
||||
# Convert if/else-if chain to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1003
|
||||
- QF1003
|
||||
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1004
|
||||
- QF1004
|
||||
# Expand call to 'math.Pow'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1005
|
||||
- QF1005
|
||||
# Lift 'if'+'break' into loop condition.
|
||||
# https://staticcheck.dev/docs/checks/#QF1006
|
||||
- QF1006
|
||||
# Merge conditional assignment into variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1007
|
||||
- QF1007
|
||||
# Omit embedded fields from selector expression.
|
||||
# https://staticcheck.dev/docs/checks/#QF1008
|
||||
- QF1008
|
||||
# Use 'time.Time.Equal' instead of '==' operator.
|
||||
# https://staticcheck.dev/docs/checks/#QF1009
|
||||
- QF1009
|
||||
# Convert slice of bytes to string when printing it.
|
||||
# https://staticcheck.dev/docs/checks/#QF1010
|
||||
- QF1010
|
||||
# Omit redundant type from variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1011
|
||||
- QF1011
|
||||
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1012
|
||||
- QF1012
|
||||
unused:
|
||||
# Mark all struct fields that have been written to as used.
|
||||
# Default: true
|
||||
field-writes-are-uses: false
|
||||
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
|
||||
field-writes-are-uses: true
|
||||
# Default: false
|
||||
post-statements-are-reads: true
|
||||
# Mark all exported fields as used.
|
||||
# default: true
|
||||
exported-fields-are-used: false
|
||||
# Mark all function parameters as used.
|
||||
# default: true
|
||||
parameters-are-used: true
|
||||
# Mark all local variables as used.
|
||||
# default: true
|
||||
local-variables-are-used: false
|
||||
# Mark all identifiers inside generated files as used.
|
||||
# Default: true
|
||||
generated-is-used: false
|
||||
exported-fields-are-used: true
|
||||
# Default: true
|
||||
parameters-are-used: true
|
||||
# Default: true
|
||||
local-variables-are-used: false
|
||||
# Default: true — must be true, ent generates 130K+ lines of code
|
||||
generated-is-used: true
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
// Server layer ProviderSet
|
||||
server.ProviderSet,
|
||||
|
||||
// Privacy client factory for OpenAI training opt-out
|
||||
providePrivacyClientFactory,
|
||||
|
||||
// BuildInfo provider
|
||||
provideServiceBuildInfo,
|
||||
|
||||
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -86,6 +93,8 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -216,6 +225,18 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
@@ -104,7 +105,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
|
||||
privacyClientFactory := providePrivacyClientFactory()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
@@ -144,6 +146,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
dataManagementService := service.NewDataManagementService()
|
||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||
dbDumper := repository.NewPgDumper(configConfig)
|
||||
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -162,9 +168,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@@ -195,7 +201,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
||||
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
||||
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
|
||||
scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -222,10 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -240,6 +251,10 @@ type Application struct {
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -273,6 +288,8 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -402,6 +419,18 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -74,6 +74,8 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
geminiOAuthSvc,
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -41,6 +41,8 @@ type Account struct {
|
||||
ProxyID *int64 `json:"proxy_id,omitempty"`
|
||||
// Concurrency holds the value of the "concurrency" field.
|
||||
Concurrency int `json:"concurrency,omitempty"`
|
||||
// LoadFactor holds the value of the "load_factor" field.
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
// Priority holds the value of the "priority" field.
|
||||
Priority int `json:"priority,omitempty"`
|
||||
// RateMultiplier holds the value of the "rate_multiplier" field.
|
||||
@@ -143,7 +145,7 @@ func (*Account) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case account.FieldRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldLoadFactor, account.FieldPriority:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -243,6 +245,13 @@ func (_m *Account) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Concurrency = int(value.Int64)
|
||||
}
|
||||
case account.FieldLoadFactor:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field load_factor", values[i])
|
||||
} else if value.Valid {
|
||||
_m.LoadFactor = new(int)
|
||||
*_m.LoadFactor = int(value.Int64)
|
||||
}
|
||||
case account.FieldPriority:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field priority", values[i])
|
||||
@@ -445,6 +454,11 @@ func (_m *Account) String() string {
|
||||
builder.WriteString("concurrency=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Concurrency))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.LoadFactor; v != nil {
|
||||
builder.WriteString("load_factor=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("priority=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -37,6 +37,8 @@ const (
|
||||
FieldProxyID = "proxy_id"
|
||||
// FieldConcurrency holds the string denoting the concurrency field in the database.
|
||||
FieldConcurrency = "concurrency"
|
||||
// FieldLoadFactor holds the string denoting the load_factor field in the database.
|
||||
FieldLoadFactor = "load_factor"
|
||||
// FieldPriority holds the string denoting the priority field in the database.
|
||||
FieldPriority = "priority"
|
||||
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
|
||||
@@ -121,6 +123,7 @@ var Columns = []string{
|
||||
FieldExtra,
|
||||
FieldProxyID,
|
||||
FieldConcurrency,
|
||||
FieldLoadFactor,
|
||||
FieldPriority,
|
||||
FieldRateMultiplier,
|
||||
FieldStatus,
|
||||
@@ -250,6 +253,11 @@ func ByConcurrency(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldConcurrency, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByLoadFactor orders the results by the load_factor field.
|
||||
func ByLoadFactor(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldLoadFactor, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPriority orders the results by the priority field.
|
||||
func ByPriority(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPriority, opts...).ToFunc()
|
||||
|
||||
@@ -100,6 +100,11 @@ func Concurrency(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldConcurrency, v))
|
||||
}
|
||||
|
||||
// LoadFactor applies equality check predicate on the "load_factor" field. It's identical to LoadFactorEQ.
|
||||
func LoadFactor(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ.
|
||||
func Priority(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||
@@ -650,6 +655,56 @@ func ConcurrencyLTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldConcurrency, v))
|
||||
}
|
||||
|
||||
// LoadFactorEQ applies the EQ predicate on the "load_factor" field.
|
||||
func LoadFactorEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorNEQ applies the NEQ predicate on the "load_factor" field.
|
||||
func LoadFactorNEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldNEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorIn applies the In predicate on the "load_factor" field.
|
||||
func LoadFactorIn(vs ...int) predicate.Account {
|
||||
return predicate.Account(sql.FieldIn(FieldLoadFactor, vs...))
|
||||
}
|
||||
|
||||
// LoadFactorNotIn applies the NotIn predicate on the "load_factor" field.
|
||||
func LoadFactorNotIn(vs ...int) predicate.Account {
|
||||
return predicate.Account(sql.FieldNotIn(FieldLoadFactor, vs...))
|
||||
}
|
||||
|
||||
// LoadFactorGT applies the GT predicate on the "load_factor" field.
|
||||
func LoadFactorGT(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldGT(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorGTE applies the GTE predicate on the "load_factor" field.
|
||||
func LoadFactorGTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldGTE(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorLT applies the LT predicate on the "load_factor" field.
|
||||
func LoadFactorLT(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLT(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorLTE applies the LTE predicate on the "load_factor" field.
|
||||
func LoadFactorLTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorIsNil applies the IsNil predicate on the "load_factor" field.
|
||||
func LoadFactorIsNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldIsNull(FieldLoadFactor))
|
||||
}
|
||||
|
||||
// LoadFactorNotNil applies the NotNil predicate on the "load_factor" field.
|
||||
func LoadFactorNotNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldNotNull(FieldLoadFactor))
|
||||
}
|
||||
|
||||
// PriorityEQ applies the EQ predicate on the "priority" field.
|
||||
func PriorityEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||
|
||||
@@ -139,6 +139,20 @@ func (_c *AccountCreate) SetNillableConcurrency(v *int) *AccountCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_c *AccountCreate) SetLoadFactor(v int) *AccountCreate {
|
||||
_c.mutation.SetLoadFactor(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_c *AccountCreate) SetNillableLoadFactor(v *int) *AccountCreate {
|
||||
if v != nil {
|
||||
_c.SetLoadFactor(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_c *AccountCreate) SetPriority(v int) *AccountCreate {
|
||||
_c.mutation.SetPriority(v)
|
||||
@@ -623,6 +637,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(account.FieldConcurrency, field.TypeInt, value)
|
||||
_node.Concurrency = value
|
||||
}
|
||||
if value, ok := _c.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
_node.LoadFactor = &value
|
||||
}
|
||||
if value, ok := _c.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
_node.Priority = value
|
||||
@@ -936,6 +954,30 @@ func (u *AccountUpsert) AddConcurrency(v int) *AccountUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsert) SetLoadFactor(v int) *AccountUpsert {
|
||||
u.Set(account.FieldLoadFactor, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsert) UpdateLoadFactor() *AccountUpsert {
|
||||
u.SetExcluded(account.FieldLoadFactor)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsert) AddLoadFactor(v int) *AccountUpsert {
|
||||
u.Add(account.FieldLoadFactor, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsert) ClearLoadFactor() *AccountUpsert {
|
||||
u.SetNull(account.FieldLoadFactor)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsert) SetPriority(v int) *AccountUpsert {
|
||||
u.Set(account.FieldPriority, v)
|
||||
@@ -1419,6 +1461,34 @@ func (u *AccountUpsertOne) UpdateConcurrency() *AccountUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsertOne) SetLoadFactor(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsertOne) AddLoadFactor(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsertOne) UpdateLoadFactor() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsertOne) ClearLoadFactor() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsertOne) SetPriority(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
@@ -2113,6 +2183,34 @@ func (u *AccountUpsertBulk) UpdateConcurrency() *AccountUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) SetLoadFactor(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) AddLoadFactor(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsertBulk) UpdateLoadFactor() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) ClearLoadFactor() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsertBulk) SetPriority(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
|
||||
@@ -172,6 +172,33 @@ func (_u *AccountUpdate) AddConcurrency(v int) *AccountUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_u *AccountUpdate) SetLoadFactor(v int) *AccountUpdate {
|
||||
_u.mutation.ResetLoadFactor()
|
||||
_u.mutation.SetLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_u *AccountUpdate) SetNillableLoadFactor(v *int) *AccountUpdate {
|
||||
if v != nil {
|
||||
_u.SetLoadFactor(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds value to the "load_factor" field.
|
||||
func (_u *AccountUpdate) AddLoadFactor(v int) *AccountUpdate {
|
||||
_u.mutation.AddLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (_u *AccountUpdate) ClearLoadFactor() *AccountUpdate {
|
||||
_u.mutation.ClearLoadFactor()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *AccountUpdate) SetPriority(v int) *AccountUpdate {
|
||||
_u.mutation.ResetPriority()
|
||||
@@ -684,6 +711,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedConcurrency(); ok {
|
||||
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedLoadFactor(); ok {
|
||||
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.LoadFactorCleared() {
|
||||
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
@@ -1063,6 +1099,33 @@ func (_u *AccountUpdateOne) AddConcurrency(v int) *AccountUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) SetLoadFactor(v int) *AccountUpdateOne {
|
||||
_u.mutation.ResetLoadFactor()
|
||||
_u.mutation.SetLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_u *AccountUpdateOne) SetNillableLoadFactor(v *int) *AccountUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetLoadFactor(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds value to the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) AddLoadFactor(v int) *AccountUpdateOne {
|
||||
_u.mutation.AddLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) ClearLoadFactor() *AccountUpdateOne {
|
||||
_u.mutation.ClearLoadFactor()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *AccountUpdateOne) SetPriority(v int) *AccountUpdateOne {
|
||||
_u.mutation.ResetPriority()
|
||||
@@ -1605,6 +1668,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
|
||||
if value, ok := _u.mutation.AddedConcurrency(); ok {
|
||||
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedLoadFactor(); ok {
|
||||
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.LoadFactorCleared() {
|
||||
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ type Announcement struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
// 状态: draft, active, archived
|
||||
Status string `json:"status,omitempty"`
|
||||
// 通知模式: silent(仅铃铛), popup(弹窗提醒)
|
||||
NotifyMode string `json:"notify_mode,omitempty"`
|
||||
// 展示条件(JSON 规则)
|
||||
Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"`
|
||||
// 开始展示时间(为空表示立即生效)
|
||||
@@ -72,7 +74,7 @@ func (*Announcement) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new([]byte)
|
||||
case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus:
|
||||
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus, announcement.FieldNotifyMode:
|
||||
values[i] = new(sql.NullString)
|
||||
case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -115,6 +117,12 @@ func (_m *Announcement) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Status = value.String
|
||||
}
|
||||
case announcement.FieldNotifyMode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field notify_mode", values[i])
|
||||
} else if value.Valid {
|
||||
_m.NotifyMode = value.String
|
||||
}
|
||||
case announcement.FieldTargeting:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field targeting", values[i])
|
||||
@@ -213,6 +221,9 @@ func (_m *Announcement) String() string {
|
||||
builder.WriteString("status=")
|
||||
builder.WriteString(_m.Status)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("notify_mode=")
|
||||
builder.WriteString(_m.NotifyMode)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("targeting=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Targeting))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -20,6 +20,8 @@ const (
|
||||
FieldContent = "content"
|
||||
// FieldStatus holds the string denoting the status field in the database.
|
||||
FieldStatus = "status"
|
||||
// FieldNotifyMode holds the string denoting the notify_mode field in the database.
|
||||
FieldNotifyMode = "notify_mode"
|
||||
// FieldTargeting holds the string denoting the targeting field in the database.
|
||||
FieldTargeting = "targeting"
|
||||
// FieldStartsAt holds the string denoting the starts_at field in the database.
|
||||
@@ -53,6 +55,7 @@ var Columns = []string{
|
||||
FieldTitle,
|
||||
FieldContent,
|
||||
FieldStatus,
|
||||
FieldNotifyMode,
|
||||
FieldTargeting,
|
||||
FieldStartsAt,
|
||||
FieldEndsAt,
|
||||
@@ -81,6 +84,10 @@ var (
|
||||
DefaultStatus string
|
||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
StatusValidator func(string) error
|
||||
// DefaultNotifyMode holds the default value on creation for the "notify_mode" field.
|
||||
DefaultNotifyMode string
|
||||
// NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
|
||||
NotifyModeValidator func(string) error
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
@@ -112,6 +119,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByNotifyMode orders the results by the notify_mode field.
|
||||
func ByNotifyMode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldNotifyMode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByStartsAt orders the results by the starts_at field.
|
||||
func ByStartsAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStartsAt, opts...).ToFunc()
|
||||
|
||||
@@ -70,6 +70,11 @@ func Status(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// NotifyMode applies equality check predicate on the "notify_mode" field. It's identical to NotifyModeEQ.
|
||||
func NotifyMode(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ.
|
||||
func StartsAt(v time.Time) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
|
||||
@@ -295,6 +300,71 @@ func StatusContainsFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v))
|
||||
}
|
||||
|
||||
// NotifyModeEQ applies the EQ predicate on the "notify_mode" field.
|
||||
func NotifyModeEQ(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeNEQ applies the NEQ predicate on the "notify_mode" field.
|
||||
func NotifyModeNEQ(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldNEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeIn applies the In predicate on the "notify_mode" field.
|
||||
func NotifyModeIn(vs ...string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldIn(FieldNotifyMode, vs...))
|
||||
}
|
||||
|
||||
// NotifyModeNotIn applies the NotIn predicate on the "notify_mode" field.
|
||||
func NotifyModeNotIn(vs ...string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldNotIn(FieldNotifyMode, vs...))
|
||||
}
|
||||
|
||||
// NotifyModeGT applies the GT predicate on the "notify_mode" field.
|
||||
func NotifyModeGT(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldGT(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeGTE applies the GTE predicate on the "notify_mode" field.
|
||||
func NotifyModeGTE(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldGTE(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeLT applies the LT predicate on the "notify_mode" field.
|
||||
func NotifyModeLT(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldLT(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeLTE applies the LTE predicate on the "notify_mode" field.
|
||||
func NotifyModeLTE(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldLTE(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeContains applies the Contains predicate on the "notify_mode" field.
|
||||
func NotifyModeContains(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContains(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeHasPrefix applies the HasPrefix predicate on the "notify_mode" field.
|
||||
func NotifyModeHasPrefix(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldHasPrefix(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeHasSuffix applies the HasSuffix predicate on the "notify_mode" field.
|
||||
func NotifyModeHasSuffix(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldHasSuffix(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeEqualFold applies the EqualFold predicate on the "notify_mode" field.
|
||||
func NotifyModeEqualFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEqualFold(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeContainsFold applies the ContainsFold predicate on the "notify_mode" field.
|
||||
func NotifyModeContainsFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContainsFold(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// TargetingIsNil applies the IsNil predicate on the "targeting" field.
|
||||
func TargetingIsNil() predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldIsNull(FieldTargeting))
|
||||
|
||||
@@ -50,6 +50,20 @@ func (_c *AnnouncementCreate) SetNillableStatus(v *string) *AnnouncementCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_c *AnnouncementCreate) SetNotifyMode(v string) *AnnouncementCreate {
|
||||
_c.mutation.SetNotifyMode(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_c *AnnouncementCreate) SetNillableNotifyMode(v *string) *AnnouncementCreate {
|
||||
if v != nil {
|
||||
_c.SetNotifyMode(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_c *AnnouncementCreate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementCreate {
|
||||
_c.mutation.SetTargeting(v)
|
||||
@@ -202,6 +216,10 @@ func (_c *AnnouncementCreate) defaults() {
|
||||
v := announcement.DefaultStatus
|
||||
_c.mutation.SetStatus(v)
|
||||
}
|
||||
if _, ok := _c.mutation.NotifyMode(); !ok {
|
||||
v := announcement.DefaultNotifyMode
|
||||
_c.mutation.SetNotifyMode(v)
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
v := announcement.DefaultCreatedAt()
|
||||
_c.mutation.SetCreatedAt(v)
|
||||
@@ -238,6 +256,14 @@ func (_c *AnnouncementCreate) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.NotifyMode(); !ok {
|
||||
return &ValidationError{Name: "notify_mode", err: errors.New(`ent: missing required field "Announcement.notify_mode"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Announcement.created_at"`)}
|
||||
}
|
||||
@@ -283,6 +309,10 @@ func (_c *AnnouncementCreate) createSpec() (*Announcement, *sqlgraph.CreateSpec)
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
_node.Status = value
|
||||
}
|
||||
if value, ok := _c.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
_node.NotifyMode = value
|
||||
}
|
||||
if value, ok := _c.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
_node.Targeting = value
|
||||
@@ -415,6 +445,18 @@ func (u *AnnouncementUpsert) UpdateStatus() *AnnouncementUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsert) SetNotifyMode(v string) *AnnouncementUpsert {
|
||||
u.Set(announcement.FieldNotifyMode, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsert) UpdateNotifyMode() *AnnouncementUpsert {
|
||||
u.SetExcluded(announcement.FieldNotifyMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsert) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsert {
|
||||
u.Set(announcement.FieldTargeting, v)
|
||||
@@ -616,6 +658,20 @@ func (u *AnnouncementUpsertOne) UpdateStatus() *AnnouncementUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsertOne) SetNotifyMode(v string) *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.SetNotifyMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsertOne) UpdateNotifyMode() *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.UpdateNotifyMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsertOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
@@ -1002,6 +1058,20 @@ func (u *AnnouncementUpsertBulk) UpdateStatus() *AnnouncementUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsertBulk) SetNotifyMode(v string) *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.SetNotifyMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsertBulk) UpdateNotifyMode() *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.UpdateNotifyMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsertBulk) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
|
||||
@@ -72,6 +72,20 @@ func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_u *AnnouncementUpdate) SetNotifyMode(v string) *AnnouncementUpdate {
|
||||
_u.mutation.SetNotifyMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_u *AnnouncementUpdate) SetNillableNotifyMode(v *string) *AnnouncementUpdate {
|
||||
if v != nil {
|
||||
_u.SetNotifyMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate {
|
||||
_u.mutation.SetTargeting(v)
|
||||
@@ -286,6 +300,11 @@ func (_u *AnnouncementUpdate) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -310,6 +329,9 @@ func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
}
|
||||
@@ -456,6 +478,20 @@ func (_u *AnnouncementUpdateOne) SetNillableStatus(v *string) *AnnouncementUpdat
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_u *AnnouncementUpdateOne) SetNotifyMode(v string) *AnnouncementUpdateOne {
|
||||
_u.mutation.SetNotifyMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_u *AnnouncementUpdateOne) SetNillableNotifyMode(v *string) *AnnouncementUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetNotifyMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne {
|
||||
_u.mutation.SetTargeting(v)
|
||||
@@ -683,6 +719,11 @@ func (_u *AnnouncementUpdateOne) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -724,6 +765,9 @@ func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announceme
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
}
|
||||
|
||||
@@ -78,6 +78,10 @@ type Group struct {
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
// 分组显示排序,数值越小越靠前
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -186,13 +190,13 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
||||
values[i] = new(sql.NullString)
|
||||
case group.FieldCreatedAt, group.FieldUpdatedAt, group.FieldDeletedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -415,6 +419,18 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.SortOrder = int(value.Int64)
|
||||
}
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field allow_messages_dispatch", values[i])
|
||||
} else if value.Valid {
|
||||
_m.AllowMessagesDispatch = value.Bool
|
||||
}
|
||||
case group.FieldDefaultMappedModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DefaultMappedModel = value.String
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -608,6 +624,12 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("sort_order=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("allow_messages_dispatch=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("default_mapped_model=")
|
||||
builder.WriteString(_m.DefaultMappedModel)
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -75,6 +75,10 @@ const (
|
||||
FieldSupportedModelScopes = "supported_model_scopes"
|
||||
// FieldSortOrder holds the string denoting the sort_order field in the database.
|
||||
FieldSortOrder = "sort_order"
|
||||
// FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database.
|
||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||
FieldDefaultMappedModel = "default_mapped_model"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -180,6 +184,8 @@ var Columns = []string{
|
||||
FieldMcpXMLInject,
|
||||
FieldSupportedModelScopes,
|
||||
FieldSortOrder,
|
||||
FieldAllowMessagesDispatch,
|
||||
FieldDefaultMappedModel,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -247,6 +253,12 @@ var (
|
||||
DefaultSupportedModelScopes []string
|
||||
// DefaultSortOrder holds the default value on creation for the "sort_order" field.
|
||||
DefaultSortOrder int
|
||||
// DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field.
|
||||
DefaultAllowMessagesDispatch bool
|
||||
// DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field.
|
||||
DefaultDefaultMappedModel string
|
||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
DefaultMappedModelValidator func(string) error
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -397,6 +409,16 @@ func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAllowMessagesDispatch orders the results by the allow_messages_dispatch field.
|
||||
func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDefaultMappedModel orders the results by the default_mapped_model field.
|
||||
func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -195,6 +195,16 @@ func SortOrder(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatch applies equality check predicate on the "allow_messages_dispatch" field. It's identical to AllowMessagesDispatchEQ.
|
||||
func AllowMessagesDispatch(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ.
|
||||
func DefaultMappedModel(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1470,6 +1480,81 @@ func SortOrderLTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatchEQ applies the EQ predicate on the "allow_messages_dispatch" field.
|
||||
func AllowMessagesDispatchEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatchNEQ applies the NEQ predicate on the "allow_messages_dispatch" field.
|
||||
func AllowMessagesDispatchNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelEQ(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelNEQ applies the NEQ predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelNEQ(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelIn applies the In predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelIn(vs ...string) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldDefaultMappedModel, vs...))
|
||||
}
|
||||
|
||||
// DefaultMappedModelNotIn applies the NotIn predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelNotIn(vs ...string) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldDefaultMappedModel, vs...))
|
||||
}
|
||||
|
||||
// DefaultMappedModelGT applies the GT predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelGT(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelGTE applies the GTE predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelGTE(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelLT applies the LT predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelLT(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelLTE applies the LTE predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelLTE(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelContains applies the Contains predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelContains(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContains(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelHasPrefix applies the HasPrefix predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelHasPrefix(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldHasPrefix(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelHasSuffix applies the HasSuffix predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelHasSuffix(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldHasSuffix(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelEqualFold applies the EqualFold predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelEqualFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEqualFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelContainsFold applies the ContainsFold predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelContainsFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -424,6 +424,34 @@ func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_c *GroupCreate) SetAllowMessagesDispatch(v bool) *GroupCreate {
|
||||
_c.mutation.SetAllowMessagesDispatch(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate {
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -613,6 +641,14 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultSortOrder
|
||||
_c.mutation.SetSortOrder(v)
|
||||
}
|
||||
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
|
||||
v := group.DefaultAllowMessagesDispatch
|
||||
_c.mutation.SetAllowMessagesDispatch(v)
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
v := group.DefaultDefaultMappedModel
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -683,6 +719,17 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.SortOrder(); !ok {
|
||||
return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
|
||||
return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -830,6 +877,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
|
||||
_node.SortOrder = value
|
||||
}
|
||||
if value, ok := _c.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
_node.AllowMessagesDispatch = value
|
||||
}
|
||||
if value, ok := _c.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
_node.DefaultMappedModel = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1520,6 +1575,30 @@ func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsert) SetAllowMessagesDispatch(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldAllowMessagesDispatch, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldAllowMessagesDispatch)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert {
|
||||
u.Set(group.FieldDefaultMappedModel, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldDefaultMappedModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -2188,6 +2267,34 @@ func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsertOne) SetAllowMessagesDispatch(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowMessagesDispatch(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowMessagesDispatch()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetDefaultMappedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateDefaultMappedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -3022,6 +3129,34 @@ func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsertBulk) SetAllowMessagesDispatch(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowMessagesDispatch(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowMessagesDispatch()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetDefaultMappedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateDefaultMappedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -625,6 +625,34 @@ func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_u *GroupUpdate) SetAllowMessagesDispatch(v bool) *GroupUpdate {
|
||||
_u.mutation.SetAllowMessagesDispatch(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -910,6 +938,11 @@ func (_u *GroupUpdate) check() error {
|
||||
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1110,6 +1143,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -2014,6 +2053,34 @@ func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_u *GroupUpdateOne) SetAllowMessagesDispatch(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetAllowMessagesDispatch(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -2312,6 +2379,11 @@ func (_u *GroupUpdateOne) check() error {
|
||||
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2529,6 +2601,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -106,6 +106,7 @@ var (
|
||||
{Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "concurrency", Type: field.TypeInt, Default: 3},
|
||||
{Name: "load_factor", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "priority", Type: field.TypeInt, Default: 50},
|
||||
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
@@ -132,7 +133,7 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "accounts_proxies_proxy",
|
||||
Columns: []*schema.Column{AccountsColumns[27]},
|
||||
Columns: []*schema.Column{AccountsColumns[28]},
|
||||
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -151,52 +152,52 @@ var (
|
||||
{
|
||||
Name: "account_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[13]},
|
||||
Columns: []*schema.Column{AccountsColumns[14]},
|
||||
},
|
||||
{
|
||||
Name: "account_proxy_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[27]},
|
||||
Columns: []*schema.Column{AccountsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[11]},
|
||||
Columns: []*schema.Column{AccountsColumns[12]},
|
||||
},
|
||||
{
|
||||
Name: "account_last_used_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[15]},
|
||||
Columns: []*schema.Column{AccountsColumns[16]},
|
||||
},
|
||||
{
|
||||
Name: "account_schedulable",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[18]},
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limited_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limit_reset_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
Columns: []*schema.Column{AccountsColumns[21]},
|
||||
},
|
||||
{
|
||||
Name: "account_overload_until",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[21]},
|
||||
Columns: []*schema.Column{AccountsColumns[22]},
|
||||
},
|
||||
{
|
||||
Name: "account_platform_priority",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]},
|
||||
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[12]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]},
|
||||
Columns: []*schema.Column{AccountsColumns[12], AccountsColumns[14]},
|
||||
},
|
||||
{
|
||||
Name: "account_deleted_at",
|
||||
@@ -250,6 +251,7 @@ var (
|
||||
{Name: "title", Type: field.TypeString, Size: 200},
|
||||
{Name: "content", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "draft"},
|
||||
{Name: "notify_mode", Type: field.TypeString, Size: 20, Default: "silent"},
|
||||
{Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
@@ -272,17 +274,17 @@ var (
|
||||
{
|
||||
Name: "announcement_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[9]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[10]},
|
||||
},
|
||||
{
|
||||
Name: "announcement_starts_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[5]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[6]},
|
||||
},
|
||||
{
|
||||
Name: "announcement_ends_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[6]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[7]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -406,6 +408,8 @@ var (
|
||||
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
|
||||
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
|
||||
@@ -2260,6 +2260,8 @@ type AccountMutation struct {
|
||||
extra *map[string]interface{}
|
||||
concurrency *int
|
||||
addconcurrency *int
|
||||
load_factor *int
|
||||
addload_factor *int
|
||||
priority *int
|
||||
addpriority *int
|
||||
rate_multiplier *float64
|
||||
@@ -2845,6 +2847,76 @@ func (m *AccountMutation) ResetConcurrency() {
|
||||
m.addconcurrency = nil
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (m *AccountMutation) SetLoadFactor(i int) {
|
||||
m.load_factor = &i
|
||||
m.addload_factor = nil
|
||||
}
|
||||
|
||||
// LoadFactor returns the value of the "load_factor" field in the mutation.
|
||||
func (m *AccountMutation) LoadFactor() (r int, exists bool) {
|
||||
v := m.load_factor
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldLoadFactor returns the old "load_factor" field's value of the Account entity.
|
||||
// If the Account object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *AccountMutation) OldLoadFactor(ctx context.Context) (v *int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldLoadFactor is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldLoadFactor requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldLoadFactor: %w", err)
|
||||
}
|
||||
return oldValue.LoadFactor, nil
|
||||
}
|
||||
|
||||
// AddLoadFactor adds i to the "load_factor" field.
|
||||
func (m *AccountMutation) AddLoadFactor(i int) {
|
||||
if m.addload_factor != nil {
|
||||
*m.addload_factor += i
|
||||
} else {
|
||||
m.addload_factor = &i
|
||||
}
|
||||
}
|
||||
|
||||
// AddedLoadFactor returns the value that was added to the "load_factor" field in this mutation.
|
||||
func (m *AccountMutation) AddedLoadFactor() (r int, exists bool) {
|
||||
v := m.addload_factor
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (m *AccountMutation) ClearLoadFactor() {
|
||||
m.load_factor = nil
|
||||
m.addload_factor = nil
|
||||
m.clearedFields[account.FieldLoadFactor] = struct{}{}
|
||||
}
|
||||
|
||||
// LoadFactorCleared returns if the "load_factor" field was cleared in this mutation.
|
||||
func (m *AccountMutation) LoadFactorCleared() bool {
|
||||
_, ok := m.clearedFields[account.FieldLoadFactor]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetLoadFactor resets all changes to the "load_factor" field.
|
||||
func (m *AccountMutation) ResetLoadFactor() {
|
||||
m.load_factor = nil
|
||||
m.addload_factor = nil
|
||||
delete(m.clearedFields, account.FieldLoadFactor)
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (m *AccountMutation) SetPriority(i int) {
|
||||
m.priority = &i
|
||||
@@ -3773,7 +3845,7 @@ func (m *AccountMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *AccountMutation) Fields() []string {
|
||||
fields := make([]string, 0, 27)
|
||||
fields := make([]string, 0, 28)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, account.FieldCreatedAt)
|
||||
}
|
||||
@@ -3807,6 +3879,9 @@ func (m *AccountMutation) Fields() []string {
|
||||
if m.concurrency != nil {
|
||||
fields = append(fields, account.FieldConcurrency)
|
||||
}
|
||||
if m.load_factor != nil {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.priority != nil {
|
||||
fields = append(fields, account.FieldPriority)
|
||||
}
|
||||
@@ -3885,6 +3960,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.ProxyID()
|
||||
case account.FieldConcurrency:
|
||||
return m.Concurrency()
|
||||
case account.FieldLoadFactor:
|
||||
return m.LoadFactor()
|
||||
case account.FieldPriority:
|
||||
return m.Priority()
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -3948,6 +4025,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
|
||||
return m.OldProxyID(ctx)
|
||||
case account.FieldConcurrency:
|
||||
return m.OldConcurrency(ctx)
|
||||
case account.FieldLoadFactor:
|
||||
return m.OldLoadFactor(ctx)
|
||||
case account.FieldPriority:
|
||||
return m.OldPriority(ctx)
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -4066,6 +4145,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetConcurrency(v)
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetLoadFactor(v)
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
@@ -4189,6 +4275,9 @@ func (m *AccountMutation) AddedFields() []string {
|
||||
if m.addconcurrency != nil {
|
||||
fields = append(fields, account.FieldConcurrency)
|
||||
}
|
||||
if m.addload_factor != nil {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.addpriority != nil {
|
||||
fields = append(fields, account.FieldPriority)
|
||||
}
|
||||
@@ -4205,6 +4294,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
|
||||
switch name {
|
||||
case account.FieldConcurrency:
|
||||
return m.AddedConcurrency()
|
||||
case account.FieldLoadFactor:
|
||||
return m.AddedLoadFactor()
|
||||
case account.FieldPriority:
|
||||
return m.AddedPriority()
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -4225,6 +4316,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddConcurrency(v)
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddLoadFactor(v)
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
@@ -4256,6 +4354,9 @@ func (m *AccountMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(account.FieldProxyID) {
|
||||
fields = append(fields, account.FieldProxyID)
|
||||
}
|
||||
if m.FieldCleared(account.FieldLoadFactor) {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.FieldCleared(account.FieldErrorMessage) {
|
||||
fields = append(fields, account.FieldErrorMessage)
|
||||
}
|
||||
@@ -4312,6 +4413,9 @@ func (m *AccountMutation) ClearField(name string) error {
|
||||
case account.FieldProxyID:
|
||||
m.ClearProxyID()
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
m.ClearLoadFactor()
|
||||
return nil
|
||||
case account.FieldErrorMessage:
|
||||
m.ClearErrorMessage()
|
||||
return nil
|
||||
@@ -4386,6 +4490,9 @@ func (m *AccountMutation) ResetField(name string) error {
|
||||
case account.FieldConcurrency:
|
||||
m.ResetConcurrency()
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
m.ResetLoadFactor()
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
m.ResetPriority()
|
||||
return nil
|
||||
@@ -5060,6 +5167,7 @@ type AnnouncementMutation struct {
|
||||
title *string
|
||||
content *string
|
||||
status *string
|
||||
notify_mode *string
|
||||
targeting *domain.AnnouncementTargeting
|
||||
starts_at *time.Time
|
||||
ends_at *time.Time
|
||||
@@ -5284,6 +5392,42 @@ func (m *AnnouncementMutation) ResetStatus() {
|
||||
m.status = nil
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (m *AnnouncementMutation) SetNotifyMode(s string) {
|
||||
m.notify_mode = &s
|
||||
}
|
||||
|
||||
// NotifyMode returns the value of the "notify_mode" field in the mutation.
|
||||
func (m *AnnouncementMutation) NotifyMode() (r string, exists bool) {
|
||||
v := m.notify_mode
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldNotifyMode returns the old "notify_mode" field's value of the Announcement entity.
|
||||
// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *AnnouncementMutation) OldNotifyMode(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldNotifyMode is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldNotifyMode requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldNotifyMode: %w", err)
|
||||
}
|
||||
return oldValue.NotifyMode, nil
|
||||
}
|
||||
|
||||
// ResetNotifyMode resets all changes to the "notify_mode" field.
|
||||
func (m *AnnouncementMutation) ResetNotifyMode() {
|
||||
m.notify_mode = nil
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (m *AnnouncementMutation) SetTargeting(dt domain.AnnouncementTargeting) {
|
||||
m.targeting = &dt
|
||||
@@ -5731,7 +5875,7 @@ func (m *AnnouncementMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *AnnouncementMutation) Fields() []string {
|
||||
fields := make([]string, 0, 10)
|
||||
fields := make([]string, 0, 11)
|
||||
if m.title != nil {
|
||||
fields = append(fields, announcement.FieldTitle)
|
||||
}
|
||||
@@ -5741,6 +5885,9 @@ func (m *AnnouncementMutation) Fields() []string {
|
||||
if m.status != nil {
|
||||
fields = append(fields, announcement.FieldStatus)
|
||||
}
|
||||
if m.notify_mode != nil {
|
||||
fields = append(fields, announcement.FieldNotifyMode)
|
||||
}
|
||||
if m.targeting != nil {
|
||||
fields = append(fields, announcement.FieldTargeting)
|
||||
}
|
||||
@@ -5776,6 +5923,8 @@ func (m *AnnouncementMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.Content()
|
||||
case announcement.FieldStatus:
|
||||
return m.Status()
|
||||
case announcement.FieldNotifyMode:
|
||||
return m.NotifyMode()
|
||||
case announcement.FieldTargeting:
|
||||
return m.Targeting()
|
||||
case announcement.FieldStartsAt:
|
||||
@@ -5805,6 +5954,8 @@ func (m *AnnouncementMutation) OldField(ctx context.Context, name string) (ent.V
|
||||
return m.OldContent(ctx)
|
||||
case announcement.FieldStatus:
|
||||
return m.OldStatus(ctx)
|
||||
case announcement.FieldNotifyMode:
|
||||
return m.OldNotifyMode(ctx)
|
||||
case announcement.FieldTargeting:
|
||||
return m.OldTargeting(ctx)
|
||||
case announcement.FieldStartsAt:
|
||||
@@ -5849,6 +6000,13 @@ func (m *AnnouncementMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetStatus(v)
|
||||
return nil
|
||||
case announcement.FieldNotifyMode:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetNotifyMode(v)
|
||||
return nil
|
||||
case announcement.FieldTargeting:
|
||||
v, ok := value.(domain.AnnouncementTargeting)
|
||||
if !ok {
|
||||
@@ -6016,6 +6174,9 @@ func (m *AnnouncementMutation) ResetField(name string) error {
|
||||
case announcement.FieldStatus:
|
||||
m.ResetStatus()
|
||||
return nil
|
||||
case announcement.FieldNotifyMode:
|
||||
m.ResetNotifyMode()
|
||||
return nil
|
||||
case announcement.FieldTargeting:
|
||||
m.ResetTargeting()
|
||||
return nil
|
||||
@@ -8089,6 +8250,8 @@ type GroupMutation struct {
|
||||
appendsupported_model_scopes []string
|
||||
sort_order *int
|
||||
addsort_order *int
|
||||
allow_messages_dispatch *bool
|
||||
default_mapped_model *string
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -9833,6 +9996,78 @@ func (m *GroupMutation) ResetSortOrder() {
|
||||
m.addsort_order = nil
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (m *GroupMutation) SetAllowMessagesDispatch(b bool) {
|
||||
m.allow_messages_dispatch = &b
|
||||
}
|
||||
|
||||
// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation.
|
||||
func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) {
|
||||
v := m.allow_messages_dispatch
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err)
|
||||
}
|
||||
return oldValue.AllowMessagesDispatch, nil
|
||||
}
|
||||
|
||||
// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field.
|
||||
func (m *GroupMutation) ResetAllowMessagesDispatch() {
|
||||
m.allow_messages_dispatch = nil
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (m *GroupMutation) SetDefaultMappedModel(s string) {
|
||||
m.default_mapped_model = &s
|
||||
}
|
||||
|
||||
// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation.
|
||||
func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) {
|
||||
v := m.default_mapped_model
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err)
|
||||
}
|
||||
return oldValue.DefaultMappedModel, nil
|
||||
}
|
||||
|
||||
// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field.
|
||||
func (m *GroupMutation) ResetDefaultMappedModel() {
|
||||
m.default_mapped_model = nil
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -10191,7 +10426,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 30)
|
||||
fields := make([]string, 0, 32)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@@ -10282,6 +10517,12 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.sort_order != nil {
|
||||
fields = append(fields, group.FieldSortOrder)
|
||||
}
|
||||
if m.allow_messages_dispatch != nil {
|
||||
fields = append(fields, group.FieldAllowMessagesDispatch)
|
||||
}
|
||||
if m.default_mapped_model != nil {
|
||||
fields = append(fields, group.FieldDefaultMappedModel)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -10350,6 +10591,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.SupportedModelScopes()
|
||||
case group.FieldSortOrder:
|
||||
return m.SortOrder()
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
return m.AllowMessagesDispatch()
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.DefaultMappedModel()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -10419,6 +10664,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldSupportedModelScopes(ctx)
|
||||
case group.FieldSortOrder:
|
||||
return m.OldSortOrder(ctx)
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
return m.OldAllowMessagesDispatch(ctx)
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.OldDefaultMappedModel(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -10638,6 +10887,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetSortOrder(v)
|
||||
return nil
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetAllowMessagesDispatch(v)
|
||||
return nil
|
||||
case group.FieldDefaultMappedModel:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetDefaultMappedModel(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -11065,6 +11328,12 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldSortOrder:
|
||||
m.ResetSortOrder()
|
||||
return nil
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
m.ResetAllowMessagesDispatch()
|
||||
return nil
|
||||
case group.FieldDefaultMappedModel:
|
||||
m.ResetDefaultMappedModel()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
|
||||
@@ -212,29 +212,29 @@ func init() {
|
||||
// account.DefaultConcurrency holds the default value on creation for the concurrency field.
|
||||
account.DefaultConcurrency = accountDescConcurrency.Default.(int)
|
||||
// accountDescPriority is the schema descriptor for priority field.
|
||||
accountDescPriority := accountFields[8].Descriptor()
|
||||
accountDescPriority := accountFields[9].Descriptor()
|
||||
// account.DefaultPriority holds the default value on creation for the priority field.
|
||||
account.DefaultPriority = accountDescPriority.Default.(int)
|
||||
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||
accountDescRateMultiplier := accountFields[9].Descriptor()
|
||||
accountDescRateMultiplier := accountFields[10].Descriptor()
|
||||
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
|
||||
// accountDescStatus is the schema descriptor for status field.
|
||||
accountDescStatus := accountFields[10].Descriptor()
|
||||
accountDescStatus := accountFields[11].Descriptor()
|
||||
// account.DefaultStatus holds the default value on creation for the status field.
|
||||
account.DefaultStatus = accountDescStatus.Default.(string)
|
||||
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
|
||||
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
|
||||
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
|
||||
accountDescAutoPauseOnExpired := accountFields[15].Descriptor()
|
||||
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
|
||||
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
|
||||
// accountDescSchedulable is the schema descriptor for schedulable field.
|
||||
accountDescSchedulable := accountFields[15].Descriptor()
|
||||
accountDescSchedulable := accountFields[16].Descriptor()
|
||||
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
||||
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
||||
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
||||
accountDescSessionWindowStatus := accountFields[23].Descriptor()
|
||||
accountDescSessionWindowStatus := accountFields[24].Descriptor()
|
||||
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
||||
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
||||
accountgroupFields := schema.AccountGroup{}.Fields()
|
||||
@@ -277,12 +277,18 @@ func init() {
|
||||
announcement.DefaultStatus = announcementDescStatus.Default.(string)
|
||||
// announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error)
|
||||
// announcementDescNotifyMode is the schema descriptor for notify_mode field.
|
||||
announcementDescNotifyMode := announcementFields[3].Descriptor()
|
||||
// announcement.DefaultNotifyMode holds the default value on creation for the notify_mode field.
|
||||
announcement.DefaultNotifyMode = announcementDescNotifyMode.Default.(string)
|
||||
// announcement.NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
|
||||
announcement.NotifyModeValidator = announcementDescNotifyMode.Validators[0].(func(string) error)
|
||||
// announcementDescCreatedAt is the schema descriptor for created_at field.
|
||||
announcementDescCreatedAt := announcementFields[8].Descriptor()
|
||||
announcementDescCreatedAt := announcementFields[9].Descriptor()
|
||||
// announcement.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time)
|
||||
// announcementDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
announcementDescUpdatedAt := announcementFields[9].Descriptor()
|
||||
announcementDescUpdatedAt := announcementFields[10].Descriptor()
|
||||
// announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time)
|
||||
// announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
@@ -447,6 +453,16 @@ func init() {
|
||||
groupDescSortOrder := groupFields[26].Descriptor()
|
||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
|
||||
groupDescAllowMessagesDispatch := groupFields[27].Descriptor()
|
||||
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
|
||||
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
|
||||
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
|
||||
groupDescDefaultMappedModel := groupFields[28].Descriptor()
|
||||
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
|
||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||
_ = idempotencyrecordMixinFields0
|
||||
|
||||
@@ -97,6 +97,8 @@ func (Account) Fields() []ent.Field {
|
||||
field.Int("concurrency").
|
||||
Default(3),
|
||||
|
||||
field.Int("load_factor").Optional().Nillable(),
|
||||
|
||||
// priority: 账户优先级,数值越小优先级越高
|
||||
// 调度器会优先使用高优先级的账户
|
||||
field.Int("priority").
|
||||
|
||||
@@ -41,6 +41,10 @@ func (Announcement) Fields() []ent.Field {
|
||||
MaxLen(20).
|
||||
Default(domain.AnnouncementStatusDraft).
|
||||
Comment("状态: draft, active, archived"),
|
||||
field.String("notify_mode").
|
||||
MaxLen(20).
|
||||
Default(domain.AnnouncementNotifyModeSilent).
|
||||
Comment("通知模式: silent(仅铃铛), popup(弹窗提醒)"),
|
||||
field.JSON("targeting", domain.AnnouncementTargeting{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
|
||||
@@ -148,6 +148,15 @@ func (Group) Fields() []ent.Field {
|
||||
field.Int("sort_order").
|
||||
Default(0).
|
||||
Comment("分组显示排序,数值越小越靠前"),
|
||||
|
||||
// OpenAI Messages 调度配置 (added by migration 069)
|
||||
field.Bool("allow_messages_dispatch").
|
||||
Default(false).
|
||||
Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"),
|
||||
field.String("default_mapped_model").
|
||||
MaxLen(100).
|
||||
Default("").
|
||||
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
module github.com/Wei-Shaw/sub2api
|
||||
|
||||
go 1.25.7
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
entgo.io/ent v0.14.5
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
@@ -38,8 +39,6 @@ require (
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.40.0
|
||||
google.golang.org/grpc v1.75.1
|
||||
google.golang.org/protobuf v1.36.10
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.44.3
|
||||
@@ -53,7 +52,6 @@ require (
|
||||
github.com/agext/levenshtein v1.2.3 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
|
||||
@@ -68,7 +66,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||
github.com/aws/smithy-go v1.24.1 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
@@ -109,7 +107,6 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
@@ -169,6 +166,7 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
go.uber.org/atomic v1.10.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
@@ -178,8 +176,8 @@ require (
|
||||
golang.org/x/mod v0.32.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/tools v0.41.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
@@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||
@@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||
@@ -171,8 +175,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -182,8 +184,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -398,8 +398,6 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
|
||||
@@ -455,8 +453,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
|
||||
|
||||
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
||||
type GatewayOpenAIWSConfig struct {
|
||||
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
||||
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
||||
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
|
||||
// IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough)
|
||||
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
||||
// Enabled: 全局总开关(默认 true)
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
|
||||
|
||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||
type DashboardAggregationRetentionConfig struct {
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
}
|
||||
|
||||
// UsageCleanupConfig 使用记录清理任务配置
|
||||
@@ -1227,7 +1228,7 @@ func setDefaults() {
|
||||
|
||||
// Ops (vNext)
|
||||
viper.SetDefault("ops.enabled", true)
|
||||
viper.SetDefault("ops.use_preaggregated_tables", false)
|
||||
viper.SetDefault("ops.use_preaggregated_tables", true)
|
||||
viper.SetDefault("ops.cleanup.enabled", true)
|
||||
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
|
||||
// Retention days: vNext defaults to 30 days across ops datasets.
|
||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||
@@ -1335,7 +1337,7 @@ func setDefaults() {
|
||||
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
|
||||
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.force_http", false)
|
||||
@@ -1402,7 +1404,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||
}
|
||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||
}
|
||||
@@ -2043,9 +2059,11 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
||||
switch mode {
|
||||
case "off", "shared", "dedicated":
|
||||
case "off", "ctx_pool", "passthrough":
|
||||
case "shared", "dedicated":
|
||||
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
|
||||
default:
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
|
||||
}
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
||||
|
||||
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
||||
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
||||
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||
}
|
||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation disabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||
@@ -1373,7 +1393,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
|
||||
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
||||
},
|
||||
{
|
||||
name: "ingress_mode_default 必须为 off|shared|dedicated",
|
||||
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
||||
wantErr: "gateway.openai_ws.ingress_mode_default",
|
||||
},
|
||||
|
||||
@@ -13,6 +13,11 @@ const (
|
||||
AnnouncementStatusArchived = "archived"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementNotifyModeSilent = "silent"
|
||||
AnnouncementNotifyModePopup = "popup"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = "subscription"
|
||||
AnnouncementConditionTypeBalance = "balance"
|
||||
@@ -195,17 +200,18 @@ func (c AnnouncementCondition) validate() error {
|
||||
}
|
||||
|
||||
type Announcement struct {
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
NotifyMode string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (a *Announcement) IsActiveAt(now time.Time) bool {
|
||||
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -84,10 +85,12 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
// Gemini 2.5 白名单
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
// Gemini 3 白名单
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
@@ -111,3 +114,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
}
|
||||
|
||||
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||
var DefaultBedrockModelMapping = map[string]string{
|
||||
// Claude Opus
|
||||
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||
// Claude Sonnet
|
||||
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
// Claude Haiku
|
||||
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string]string{
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
||||
}
|
||||
}
|
||||
|
||||
enrichCredentialsFromIDToken(&item)
|
||||
|
||||
accountInput := &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
|
||||
return name
|
||||
}
|
||||
|
||||
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||
// Existing credential values are never overwritten — only missing fields are filled.
|
||||
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||
if item.Credentials == nil {
|
||||
return
|
||||
}
|
||||
// Only enrich OpenAI/Sora OAuth accounts
|
||||
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
|
||||
idToken, _ := item.Credentials["id_token"].(string)
|
||||
if strings.TrimSpace(idToken) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeIDToken skips expiry validation — safe for imported data
|
||||
claims, err := openai.DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo := claims.GetUserInfo()
|
||||
if userInfo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fill missing fields only (never overwrite existing values)
|
||||
setIfMissing := func(key, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
if existing, _ := item.Credentials[key].(string); existing == "" {
|
||||
item.Credentials[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
setIfMissing("email", userInfo.Email)
|
||||
setIfMissing("plan_type", userInfo.PlanType)
|
||||
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
|
||||
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
|
||||
setIfMissing("organization_id", userInfo.OrganizationID)
|
||||
}
|
||||
|
||||
func normalizeProxyStatus(status string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(status))
|
||||
switch normalized {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -95,13 +97,14 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -113,14 +116,15 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -135,6 +139,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
@@ -217,6 +222,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
lite := parseBoolQueryWithDefault(c.Query("lite"), false)
|
||||
|
||||
var groupID int64
|
||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||
@@ -235,10 +241,16 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
accountIDs[i] = acc.ID
|
||||
}
|
||||
|
||||
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
|
||||
if err != nil {
|
||||
// Log error but don't fail the request, just use 0 for all
|
||||
concurrencyCounts = make(map[int64]int)
|
||||
concurrencyCounts := make(map[int64]int)
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
var rpmCounts map[int64]int
|
||||
|
||||
// 始终获取并发数(Redis ZCARD,极低开销)
|
||||
if h.concurrencyService != nil {
|
||||
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||
concurrencyCounts = cc
|
||||
}
|
||||
}
|
||||
|
||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
@@ -262,12 +274,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 并行获取窗口费用、活跃会话数和 RPM 计数
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
var rpmCounts map[int64]int
|
||||
|
||||
// 获取 RPM 计数(批量查询)
|
||||
// 始终获取 RPM 计数(Redis GET,极低开销)
|
||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||
if rpmCounts == nil {
|
||||
@@ -275,7 +282,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
// 始终获取活跃会话数(Redis ZCARD,低开销)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
@@ -283,7 +290,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获取窗口费用(并行查询)
|
||||
// 始终获取窗口费用(PostgreSQL 聚合查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
@@ -344,7 +351,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
result[i] = item
|
||||
}
|
||||
|
||||
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
|
||||
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite)
|
||||
if etag != "" {
|
||||
c.Header("ETag", etag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
@@ -362,6 +369,7 @@ func buildAccountsListETag(
|
||||
total int64,
|
||||
page, pageSize int,
|
||||
platform, accountType, status, search string,
|
||||
lite bool,
|
||||
) string {
|
||||
payload := struct {
|
||||
Total int64 `json:"total"`
|
||||
@@ -371,6 +379,7 @@ func buildAccountsListETag(
|
||||
AccountType string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Search string `json:"search"`
|
||||
Lite bool `json:"lite"`
|
||||
Items []AccountWithConcurrency `json:"items"`
|
||||
}{
|
||||
Total: total,
|
||||
@@ -380,6 +389,7 @@ func buildAccountsListETag(
|
||||
AccountType: accountType,
|
||||
Status: status,
|
||||
Search: search,
|
||||
Lite: lite,
|
||||
Items: items,
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
@@ -501,6 +511,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
@@ -570,6 +581,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
@@ -616,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
// TestAccountRequest represents the request body for testing an account
|
||||
type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
@@ -646,10 +659,46 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// Use AccountTestService to test the account with SSE streaming
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService != nil {
|
||||
if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverState handles unified recovery of recoverable account runtime state.
|
||||
// POST /api/v1/admin/accounts/:id/recover-state
|
||||
func (h *AccountHandler) RecoverState(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
}); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
@@ -705,52 +754,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts (oauth and setup-token)
|
||||
// refreshSingleAccount refreshes credentials for a single OAuth account.
|
||||
// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
|
||||
func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
|
||||
}
|
||||
|
||||
var newCredentials map[string]any
|
||||
|
||||
if account.IsOpenAI() {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformGemini {
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
|
||||
}
|
||||
|
||||
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -760,10 +788,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformAntigravity {
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -782,37 +809,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
return updatedAccount, "missing_project_id_temporary", nil
|
||||
}
|
||||
|
||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||
return
|
||||
if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
|
||||
return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
@@ -834,20 +851,54 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
|
||||
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
|
||||
|
||||
return updatedAccount, "", nil
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
if warning == "missing_project_id_temporary" {
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||
@@ -903,14 +954,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// BatchClearError handles batch clearing account errors
|
||||
// POST /api/v1/admin/accounts/batch-clear-error
|
||||
func (h *AccountHandler) BatchClearError(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, id := range req.AccountIDs {
|
||||
accountID := id // 闭包捕获
|
||||
g.Go(func() error {
|
||||
account, err := h.adminService.ClearAccountError(gctx, accountID)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": accountID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
successCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchRefresh handles batch refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/batch-refresh
|
||||
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 建立已获取账号的 ID 集合,检测缺失的 ID
|
||||
foundIDs := make(map[int64]bool, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc != nil {
|
||||
foundIDs[acc.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
var warnings []gin.H
|
||||
|
||||
// 将不存在的账号 ID 标记为失败
|
||||
for _, id := range req.AccountIDs {
|
||||
if !foundIDs[id] {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": id,
|
||||
"error": "account not found",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, account := range accounts {
|
||||
acc := account // 闭包捕获
|
||||
if acc == nil {
|
||||
continue
|
||||
}
|
||||
g.Go(func() error {
|
||||
_, warning, err := h.refreshSingleAccount(gctx, acc)
|
||||
mu.Lock()
|
||||
if err != nil {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
successCount++
|
||||
if warning != "" {
|
||||
warnings = append(warnings, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"warning": warning,
|
||||
})
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
// POST /api/v1/admin/accounts/batch
|
||||
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
@@ -1096,6 +1308,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.RateMultiplier != nil ||
|
||||
req.LoadFactor != nil ||
|
||||
req.Status != "" ||
|
||||
req.Schedulable != nil ||
|
||||
req.GroupIDs != nil ||
|
||||
@@ -1114,6 +1327,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
Schedulable: req.Schedulable,
|
||||
GroupIDs: req.GroupIDs,
|
||||
@@ -1323,6 +1537,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// ResetQuota handles resetting account quota usage
|
||||
// POST /api/v1/admin/accounts/:id/reset-quota
|
||||
func (h *AccountHandler) ResetQuota(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ResetAccountQuota(c.Request.Context(), accountID); err != nil {
|
||||
response.InternalError(c, "Failed to reset account quota: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
// GET /api/v1/admin/accounts/:id/temp-unschedulable
|
||||
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
|
||||
@@ -1398,18 +1635,41 @@ func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
accountIDs := normalizeInt64IDList(req.AccountIDs)
|
||||
if len(accountIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
|
||||
cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs)
|
||||
if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
payload := gin.H{"stats": stats}
|
||||
cached := accountTodayStatsBatchCache.Set(cacheKey, payload)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// SetSchedulableRequest represents the request body for setting schedulable status
|
||||
@@ -1458,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||
if account.IsOpenAIPassthroughEnabled() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type availableModelsAdminService struct {
|
||||
*stubAdminService
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID == id {
|
||||
acc := s.account
|
||||
return &acc, nil
|
||||
}
|
||||
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||
}
|
||||
|
||||
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 42,
|
||||
Name: "openai-oauth",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 1)
|
||||
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 43,
|
||||
Name: "openai-oauth-passthrough",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
25
backend/internal/handler/admin/account_today_stats_cache.go
Normal file
25
backend/internal/handler/admin/account_today_stats_cache.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string {
|
||||
if len(accountIDs) == 0 {
|
||||
return "accounts_today_stats_empty"
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(accountIDs) * 6)
|
||||
_, _ = b.WriteString("accounts_today_stats:")
|
||||
for i, id := range accountIDs {
|
||||
if i > 0 {
|
||||
_ = b.WriteByte(',')
|
||||
}
|
||||
_, _ = b.WriteString(strconv.FormatInt(id, 10))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
@@ -175,6 +175,18 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
@@ -425,5 +437,13 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
@@ -27,21 +27,23 @@ func NewAnnouncementHandler(announcementService *service.AnnouncementService) *A
|
||||
}
|
||||
|
||||
type CreateAnnouncementRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
NotifyMode string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
}
|
||||
|
||||
type UpdateAnnouncementRequest struct {
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
NotifyMode *string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
}
|
||||
|
||||
// List handles listing announcements with filters
|
||||
@@ -110,11 +112,12 @@ func (h *AnnouncementHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.CreateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
NotifyMode: req.NotifyMode,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil && *req.StartsAt > 0 {
|
||||
@@ -157,11 +160,12 @@ func (h *AnnouncementHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.UpdateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
NotifyMode: req.NotifyMode,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil {
|
||||
|
||||
204
backend/internal/handler/admin/backup_handler.go
Normal file
204
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupHandler struct {
|
||||
backupService *service.BackupService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||
return &BackupHandler{
|
||||
backupService: backupService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置 ───
|
||||
|
||||
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||
}
|
||||
|
||||
// ─── 定时备份 ───
|
||||
|
||||
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||
var req service.BackupScheduleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// ─── 备份操作 ───
|
||||
|
||||
type CreateBackupRequest struct {
|
||||
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||
}
|
||||
|
||||
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||
var req CreateBackupRequest
|
||||
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||
|
||||
expireDays := 14 // 默认14天过期
|
||||
if req.ExpireDays != nil {
|
||||
expireDays = *req.ExpireDays
|
||||
}
|
||||
|
||||
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []service.BackupRecord{}
|
||||
}
|
||||
response.Success(c, gin.H{"items": records})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req RestoreBackupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "password is required for restore operation")
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取当前管理员用户 ID
|
||||
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取管理员用户并验证密码
|
||||
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !user.CheckPassword(req.Password) {
|
||||
response.BadRequest(c, "incorrect admin password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"restored": true})
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -248,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -320,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
@@ -390,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"groups": stats,
|
||||
@@ -415,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -441,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -460,6 +466,60 @@ type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
func parseRankingLimit(raw string) int {
|
||||
limit, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || limit <= 0 {
|
||||
return 12
|
||||
}
|
||||
if limit > 50 {
|
||||
return 50
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||
// GET /api/v1/admin/dashboard/users-ranking
|
||||
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
|
||||
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
Start string `json:"start"`
|
||||
End string `json:"end"`
|
||||
Limit int `json:"limit"`
|
||||
}{
|
||||
Start: startTime.UTC().Format(time.RFC3339),
|
||||
End: endTime.UTC().Format(time.RFC3339),
|
||||
Limit: limit,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user spending ranking")
|
||||
return
|
||||
}
|
||||
|
||||
payload := gin.H{
|
||||
"ranking": ranking.Ranking,
|
||||
"total_actual_cost": ranking.TotalActualCost,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
}
|
||||
dashboardUsersRankingCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
@@ -469,18 +529,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
userIDs := normalizeInt64IDList(req.UserIDs)
|
||||
if len(userIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
}{
|
||||
UserIDs: userIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
payload := gin.H{"stats": stats}
|
||||
dashboardBatchUsersUsageCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
|
||||
@@ -497,16 +573,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.APIKeyIDs) == 0 {
|
||||
apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
APIKeyIDs []int64 `json:"api_key_ids"`
|
||||
}{
|
||||
APIKeyIDs: apiKeyIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
payload := gin.H{"stats": stats}
|
||||
dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardUsageRepoCacheProbe struct {
|
||||
service.UsageLogRepository
|
||||
trendCalls atomic.Int32
|
||||
usersTrendCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
r.trendCalls.Add(1)
|
||||
return []usagestats.TrendDataPoint{{
|
||||
Date: "2026-03-11",
|
||||
Requests: 1,
|
||||
TotalTokens: 2,
|
||||
Cost: 3,
|
||||
ActualCost: 4,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
limit int,
|
||||
) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
r.usersTrendCalls.Add(1)
|
||||
return []usagestats.UserUsageTrendPoint{{
|
||||
Date: "2026-03-11",
|
||||
UserID: 1,
|
||||
Email: "cache@test.dev",
|
||||
Requests: 2,
|
||||
Tokens: 20,
|
||||
Cost: 2,
|
||||
ActualCost: 1,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func resetDashboardReadCachesForTest() {
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.trendCalls.Load())
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
|
||||
}
|
||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
||||
trendStream *bool
|
||||
modelRequestType *int16
|
||||
modelStream *bool
|
||||
rankingLimit int
|
||||
ranking []usagestats.UserSpendingRankingItem
|
||||
rankingTotal float64
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||
@@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||
return []usagestats.ModelStat{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
limit int,
|
||||
) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
s.rankingLimit = limit
|
||||
return &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: s.ranking,
|
||||
TotalActualCost: s.rankingTotal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
@@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
repo := &dashboardUsageRepoCapture{
|
||||
ranking: []usagestats.UserSpendingRankingItem{
|
||||
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
|
||||
},
|
||||
rankingTotal: 88.8,
|
||||
}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 50, repo.rankingLimit)
|
||||
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
}
|
||||
|
||||
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
var (
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
)
|
||||
|
||||
type dashboardTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Model string `json:"model"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardModelGroupCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardEntityTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
func cacheStatusValue(hit bool) string {
|
||||
if hit {
|
||||
return "hit"
|
||||
}
|
||||
return "miss"
|
||||
}
|
||||
|
||||
func mustMarshalDashboardCacheKey(value any) string {
|
||||
raw, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func snapshotPayloadAs[T any](payload any) (T, error) {
|
||||
typed, ok := payload.(T)
|
||||
if !ok {
|
||||
var zero T
|
||||
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
|
||||
}
|
||||
return typed, nil
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUsageTrendCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getModelStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.ModelStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getGroupStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.GroupStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
302
backend/internal/handler/admin/dashboard_snapshot_v2_handler.go
Normal file
302
backend/internal/handler/admin/dashboard_snapshot_v2_handler.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
type dashboardSnapshotV2Stats struct {
|
||||
usagestats.DashboardStats
|
||||
Uptime int64 `json:"uptime"`
|
||||
}
|
||||
|
||||
type dashboardSnapshotV2Response struct {
|
||||
GeneratedAt string `json:"generated_at"`
|
||||
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
Granularity string `json:"granularity"`
|
||||
|
||||
Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"`
|
||||
Trend []usagestats.TrendDataPoint `json:"trend,omitempty"`
|
||||
Models []usagestats.ModelStat `json:"models,omitempty"`
|
||||
Groups []usagestats.GroupStat `json:"groups,omitempty"`
|
||||
UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"`
|
||||
}
|
||||
|
||||
type dashboardSnapshotV2Filters struct {
|
||||
UserID int64
|
||||
APIKeyID int64
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
RequestType *int16
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
}
|
||||
|
||||
type dashboardSnapshotV2CacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Model string `json:"model"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
IncludeStats bool `json:"include_stats"`
|
||||
IncludeTrend bool `json:"include_trend"`
|
||||
IncludeModels bool `json:"include_models"`
|
||||
IncludeGroups bool `json:"include_groups"`
|
||||
IncludeUsersTrend bool `json:"include_users_trend"`
|
||||
UsersTrendLimit int `json:"users_trend_limit"`
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day"))
|
||||
if granularity != "hour" {
|
||||
granularity = "day"
|
||||
}
|
||||
|
||||
includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true)
|
||||
includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true)
|
||||
includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true)
|
||||
includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false)
|
||||
includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false)
|
||||
usersTrendLimit := 12
|
||||
if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" {
|
||||
if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 {
|
||||
usersTrendLimit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
filters, err := parseDashboardSnapshotV2Filters(c)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
UserID: filters.UserID,
|
||||
APIKeyID: filters.APIKeyID,
|
||||
AccountID: filters.AccountID,
|
||||
GroupID: filters.GroupID,
|
||||
Model: filters.Model,
|
||||
RequestType: filters.RequestType,
|
||||
Stream: filters.Stream,
|
||||
BillingType: filters.BillingType,
|
||||
IncludeStats: includeStats,
|
||||
IncludeTrend: includeTrend,
|
||||
IncludeModels: includeModels,
|
||||
IncludeGroups: includeGroups,
|
||||
IncludeUsersTrend: includeUsersTrend,
|
||||
UsersTrendLimit: usersTrendLimit,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
|
||||
return h.buildSnapshotV2Response(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
filters,
|
||||
includeStats,
|
||||
includeTrend,
|
||||
includeModels,
|
||||
includeGroups,
|
||||
includeUsersTrend,
|
||||
usersTrendLimit,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
response.Error(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
response.Success(c, cached.Payload)
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) buildSnapshotV2Response(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
filters *dashboardSnapshotV2Filters,
|
||||
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
|
||||
usersTrendLimit int,
|
||||
) (*dashboardSnapshotV2Response, error) {
|
||||
resp := &dashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
StartDate: startTime.Format("2006-01-02"),
|
||||
EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
Granularity: granularity,
|
||||
}
|
||||
|
||||
if includeStats {
|
||||
stats, err := h.dashboardService.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get dashboard statistics")
|
||||
}
|
||||
resp.Stats = &dashboardSnapshotV2Stats{
|
||||
DashboardStats: *stats,
|
||||
Uptime: int64(time.Since(h.startTime).Seconds()),
|
||||
}
|
||||
}
|
||||
|
||||
if includeTrend {
|
||||
trend, _, err := h.getUsageTrendCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
filters.Model,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get usage trend")
|
||||
}
|
||||
resp.Trend = trend
|
||||
}
|
||||
|
||||
if includeModels {
|
||||
models, _, err := h.getModelStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get model statistics")
|
||||
}
|
||||
resp.Models = models
|
||||
}
|
||||
|
||||
if includeGroups {
|
||||
groups, _, err := h.getGroupStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get group statistics")
|
||||
}
|
||||
resp.Groups = groups
|
||||
}
|
||||
|
||||
if includeUsersTrend {
|
||||
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get user usage trend")
|
||||
}
|
||||
resp.UsersTrend = usersTrend
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||
filters := &dashboardSnapshotV2Filters{
|
||||
Model: strings.TrimSpace(c.Query("model")),
|
||||
}
|
||||
|
||||
if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.UserID = id
|
||||
}
|
||||
if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.APIKeyID = id
|
||||
}
|
||||
if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" {
|
||||
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.AccountID = id
|
||||
}
|
||||
if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" {
|
||||
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.GroupID = id
|
||||
}
|
||||
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value := int16(parsed)
|
||||
filters.RequestType = &value
|
||||
} else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" {
|
||||
streamVal, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.Stream = &streamVal
|
||||
}
|
||||
|
||||
if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" {
|
||||
v, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bt := int8(v)
|
||||
filters.BillingType = &bt
|
||||
}
|
||||
|
||||
return filters, nil
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -16,6 +19,55 @@ type GroupHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
type optionalLimitField struct {
|
||||
set bool
|
||||
value *float64
|
||||
}
|
||||
|
||||
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
|
||||
f.set = true
|
||||
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if bytes.Equal(trimmed, []byte("null")) {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var number float64
|
||||
if err := json.Unmarshal(trimmed, &number); err == nil {
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(trimmed, &text); err == nil {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
number, err = strconv.ParseFloat(text, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
|
||||
}
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid limit value: %s", string(trimmed))
|
||||
}
|
||||
|
||||
func (f optionalLimitField) ToServiceInput() *float64 {
|
||||
if !f.set {
|
||||
return nil
|
||||
}
|
||||
if f.value != nil {
|
||||
return f.value
|
||||
}
|
||||
zero := 0.0
|
||||
return &zero
|
||||
}
|
||||
|
||||
// NewGroupHandler creates a new admin group handler
|
||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
@@ -25,15 +77,15 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
|
||||
// CreateGroupRequest represents create group request
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -53,22 +105,25 @@ type CreateGroupRequest struct {
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -88,6 +143,9 @@ type UpdateGroupRequest struct {
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel *string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -185,9 +243,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@@ -203,6 +261,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -236,9 +296,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@@ -254,6 +314,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -325,6 +387,72 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if entries == nil {
|
||||
entries = []service.UserGroupRateEntry{}
|
||||
}
|
||||
response.Success(c, entries)
|
||||
}
|
||||
|
||||
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
|
||||
// DELETE /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
|
||||
type BatchSetGroupRateMultipliersRequest struct {
|
||||
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
|
||||
// PUT /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchSetGroupRateMultipliersRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
|
||||
25
backend/internal/handler/admin/id_list_utils.go
Normal file
25
backend/internal/handler/admin/id_list_utils.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package admin
|
||||
|
||||
import "sort"
|
||||
|
||||
func normalizeInt64IDList(ids []int64) []int64 {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]int64, 0, len(ids))
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
|
||||
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
|
||||
return out
|
||||
}
|
||||
57
backend/internal/handler/admin/id_list_utils_test.go
Normal file
57
backend/internal/handler/admin/id_list_utils_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalizeInt64IDList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in []int64
|
||||
want []int64
|
||||
}{
|
||||
{"nil input", nil, nil},
|
||||
{"empty input", []int64{}, nil},
|
||||
{"single element", []int64{5}, []int64{5}},
|
||||
{"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}},
|
||||
{"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}},
|
||||
{"zero filtered", []int64{0, 1, 2}, []int64{1, 2}},
|
||||
{"negative filtered", []int64{-5, -1, 3}, []int64{3}},
|
||||
{"all invalid", []int64{0, -1, -2}, []int64{}},
|
||||
{"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := normalizeInt64IDList(tc.in)
|
||||
if tc.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
require.Equal(t, tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []int64
|
||||
want string
|
||||
}{
|
||||
{"empty", nil, "accounts_today_stats_empty"},
|
||||
{"single", []int64{42}, "accounts_today_stats:42"},
|
||||
{"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := buildAccountTodayStatsBatchCacheKey(tc.ids)
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
Platform: platform,
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
Extra: nil,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
|
||||
@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent",
|
||||
"concurrency_queue_depth",
|
||||
"group_available_accounts",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_rate_limited_count",
|
||||
"account_error_count",
|
||||
"account_error_ratio",
|
||||
"overload_account_count",
|
||||
}
|
||||
|
||||
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
|
||||
"error_rate",
|
||||
"upstream_error_rate",
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent":
|
||||
"memory_usage_percent",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_error_ratio":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
145
backend/internal/handler/admin/ops_snapshot_v2_handler.go
Normal file
145
backend/internal/handler/admin/ops_snapshot_v2_handler.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
type opsDashboardSnapshotV2Response struct {
|
||||
GeneratedAt string `json:"generated_at"`
|
||||
|
||||
Overview *service.OpsDashboardOverview `json:"overview"`
|
||||
ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"`
|
||||
ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"`
|
||||
}
|
||||
|
||||
type opsDashboardSnapshotV2CacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
QueryMode service.OpsQueryMode `json:"mode"`
|
||||
BucketSecond int `json:"bucket_second"`
|
||||
}
|
||||
|
||||
// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request.
|
||||
// GET /api/v1/admin/ops/dashboard/snapshot-v2
|
||||
func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsDashboardFilter{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
Platform: strings.TrimSpace(c.Query("platform")),
|
||||
QueryMode: parseOpsQueryMode(c),
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||
id, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
filter.GroupID = &id
|
||||
}
|
||||
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
|
||||
|
||||
keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Platform: filter.Platform,
|
||||
GroupID: filter.GroupID,
|
||||
QueryMode: filter.QueryMode,
|
||||
BucketSecond: bucketSeconds,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
overview *service.OpsDashboardOverview
|
||||
trend *service.OpsThroughputTrendResponse
|
||||
errTrend *service.OpsErrorTrendResponse
|
||||
)
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.Go(func() error {
|
||||
f := *filter
|
||||
result, err := h.opsService.GetDashboardOverview(gctx, &f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
overview = result
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
f := *filter
|
||||
result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
trend = result
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
f := *filter
|
||||
result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
errTrend = result
|
||||
return nil
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := &opsDashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
Overview: overview,
|
||||
ThroughputTrend: trend,
|
||||
ErrorTrend: errTrend,
|
||||
}
|
||||
|
||||
cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
}
|
||||
@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
|
||||
}
|
||||
|
||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||
type CreateAndRedeemCodeRequest struct {
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
Notes string `json:"notes"`
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Code = strings.TrimSpace(req.Code)
|
||||
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||
if req.Type == "" {
|
||||
req.Type = "balance"
|
||||
}
|
||||
|
||||
if req.Type == "subscription" {
|
||||
if req.GroupID == nil {
|
||||
response.BadRequest(c, "group_id is required for subscription type")
|
||||
return
|
||||
}
|
||||
if req.ValidityDays <= 0 {
|
||||
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
}
|
||||
|
||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if createErr != nil {
|
||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||
|
||||
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||
// parameter-validation layer that runs before any service call.
|
||||
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: newStubAdminService(),
|
||||
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||
}
|
||||
}
|
||||
|
||||
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||
// status code. For cases that pass validation and proceed into the service layer,
|
||||
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
jsonBytes, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||
code = 0
|
||||
}
|
||||
}()
|
||||
handler.CreateAndRedeem(c)
|
||||
return w.Code
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-default",
|
||||
"value": 10.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"omitting type should default to balance and pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-no-group",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"validity_days": 30,
|
||||
// group_id 缺失
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
validityDays int
|
||||
}{
|
||||
{"zero", 0},
|
||||
{"negative", -1},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-bad-days-" + tc.name,
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": tc.validityDays,
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-valid",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": 31,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"valid subscription params should pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-no-extras",
|
||||
"type": "balance",
|
||||
"value": 50.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"balance type should not require group_id or validity_days")
|
||||
}
|
||||
163
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
163
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ScheduledTestHandler handles admin scheduled-test-plan management.
|
||||
type ScheduledTestHandler struct {
|
||||
scheduledTestSvc *service.ScheduledTestService
|
||||
}
|
||||
|
||||
// NewScheduledTestHandler creates a new ScheduledTestHandler.
|
||||
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
|
||||
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
|
||||
}
|
||||
|
||||
type createScheduledTestPlanRequest struct {
|
||||
AccountID int64 `json:"account_id" binding:"required"`
|
||||
ModelID string `json:"model_id"`
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid account id")
|
||||
return
|
||||
}
|
||||
|
||||
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, plans)
|
||||
}
|
||||
|
||||
// Create POST /admin/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
var req createScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
plan := &service.ScheduledTestPlan{
|
||||
AccountID: req.AccountID,
|
||||
ModelID: req.ModelID,
|
||||
CronExpression: req.CronExpression,
|
||||
Enabled: true,
|
||||
MaxResults: req.MaxResults,
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
plan.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// Update PUT /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "plan not found")
|
||||
return
|
||||
}
|
||||
|
||||
var req updateScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelID != "" {
|
||||
existing.ModelID = req.ModelID
|
||||
}
|
||||
if req.CronExpression != "" {
|
||||
existing.CronExpression = req.CronExpression
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
existing.Enabled = *req.Enabled
|
||||
}
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
existing.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// Delete DELETE /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
|
||||
}
|
||||
|
||||
// ListResults GET /admin/scheduled-test-plans/:id/results
|
||||
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
|
||||
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, results)
|
||||
}
|
||||
@@ -77,8 +77,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
FrontendURL: settings.FrontendURL,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -124,18 +126,21 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@@ -197,6 +202,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -320,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Frontend URL 验证
|
||||
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||
if req.FrontendURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义菜单项验证
|
||||
const (
|
||||
maxCustomMenuItems = 20
|
||||
@@ -426,50 +443,53 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -520,8 +540,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
FrontendURL: updatedSettings.FrontendURL,
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -567,6 +589,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -598,9 +621,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||
changed = append(changed, "email_verify_enabled")
|
||||
}
|
||||
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
|
||||
changed = append(changed, "registration_email_suffix_whitelist")
|
||||
}
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.FrontendURL != after.FrontendURL {
|
||||
changed = append(changed, "frontend_url")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
@@ -718,6 +747,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||
changed = append(changed, "backend_mode_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
@@ -747,6 +779,18 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
|
||||
return normalized
|
||||
}
|
||||
|
||||
func equalStringSlice(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
@@ -800,7 +844,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
|
||||
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -886,7 +930,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "Failed to send test email: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1329,6 +1373,118 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
// GET /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: settings.Enabled,
|
||||
ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRectifierSettingsRequest 更新整流器配置请求
|
||||
type UpdateRectifierSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// UpdateRectifierSettings 更新请求整流器配置
|
||||
// PUT /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
var req UpdateRectifierSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.RectifierSettings{
|
||||
Enabled: req.Enabled,
|
||||
ThinkingSignatureEnabled: req.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: req.ThinkingBudgetEnabled,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||
// GET /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
|
||||
for i, r := range settings.Rules {
|
||||
rules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: rules})
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
|
||||
type UpdateBetaPolicySettingsRequest struct {
|
||||
Rules []dto.BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettings 更新 Beta 策略配置
|
||||
// PUT /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
|
||||
var req UpdateBetaPolicySettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]service.BetaPolicyRule, len(req.Rules))
|
||||
for i, r := range req.Rules {
|
||||
rules[i] = service.BetaPolicyRule(r)
|
||||
}
|
||||
|
||||
settings := &service.BetaPolicySettings{Rules: rules}
|
||||
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch to return updated settings
|
||||
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
|
||||
for i, r := range updated.Rules {
|
||||
outRules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
138
backend/internal/handler/admin/snapshot_cache.go
Normal file
138
backend/internal/handler/admin/snapshot_cache.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
type snapshotCacheEntry struct {
|
||||
ETag string
|
||||
Payload any
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type snapshotCache struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
items map[string]snapshotCacheEntry
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
type snapshotCacheLoadResult struct {
|
||||
Entry snapshotCacheEntry
|
||||
Hit bool
|
||||
}
|
||||
|
||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||
if ttl <= 0 {
|
||||
ttl = 30 * time.Second
|
||||
}
|
||||
return &snapshotCache{
|
||||
ttl: ttl,
|
||||
items: make(map[string]snapshotCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) {
|
||||
if c == nil || key == "" {
|
||||
return snapshotCacheEntry{}, false
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
c.mu.RLock()
|
||||
entry, ok := c.items[key]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return snapshotCacheEntry{}, false
|
||||
}
|
||||
if now.After(entry.ExpiresAt) {
|
||||
c.mu.Lock()
|
||||
delete(c.items, key)
|
||||
c.mu.Unlock()
|
||||
return snapshotCacheEntry{}, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
||||
if c == nil {
|
||||
return snapshotCacheEntry{}
|
||||
}
|
||||
entry := snapshotCacheEntry{
|
||||
ETag: buildETagFromAny(payload),
|
||||
Payload: payload,
|
||||
ExpiresAt: time.Now().Add(c.ttl),
|
||||
}
|
||||
if key == "" {
|
||||
return entry
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.items[key] = entry
|
||||
c.mu.Unlock()
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
|
||||
if load == nil {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return entry, true, nil
|
||||
}
|
||||
if c == nil || key == "" {
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
return c.Set(key, payload), false, nil
|
||||
}
|
||||
|
||||
value, err, _ := c.sf.Do(key, func() (any, error) {
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
|
||||
}
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
result, ok := value.(snapshotCacheLoadResult)
|
||||
if !ok {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
return result.Entry, result.Hit, nil
|
||||
}
|
||||
|
||||
func buildETagFromAny(payload any) string {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(raw)
|
||||
return "\"" + hex.EncodeToString(sum[:]) + "\""
|
||||
}
|
||||
|
||||
func parseBoolQueryWithDefault(raw string, def bool) bool {
|
||||
value := strings.TrimSpace(strings.ToLower(raw))
|
||||
if value == "" {
|
||||
return def
|
||||
}
|
||||
switch value {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return def
|
||||
}
|
||||
}
|
||||
185
backend/internal/handler/admin/snapshot_cache_test.go
Normal file
185
backend/internal/handler/admin/snapshot_cache_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSnapshotCache_SetAndGet(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
|
||||
entry := c.Set("key1", map[string]string{"hello": "world"})
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
require.NotNil(t, entry.Payload)
|
||||
|
||||
got, ok := c.Get("key1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, entry.ETag, got.ETag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_Expiration(t *testing.T) {
|
||||
c := newSnapshotCache(1 * time.Millisecond)
|
||||
|
||||
c.Set("key1", "value")
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
_, ok := c.Get("key1")
|
||||
require.False(t, ok, "expired entry should not be returned")
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetEmptyKey(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
_, ok := c.Get("")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetMiss(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
_, ok := c.Get("nonexistent")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_NilReceiver(t *testing.T) {
|
||||
var c *snapshotCache
|
||||
_, ok := c.Get("key")
|
||||
require.False(t, ok)
|
||||
|
||||
entry := c.Set("key", "value")
|
||||
require.Empty(t, entry.ETag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_SetEmptyKey(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
|
||||
// Set with empty key should return entry but not store it
|
||||
entry := c.Set("", "value")
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
|
||||
_, ok := c.Get("")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_DefaultTTL(t *testing.T) {
|
||||
c := newSnapshotCache(0)
|
||||
require.Equal(t, 30*time.Second, c.ttl)
|
||||
|
||||
c2 := newSnapshotCache(-1 * time.Second)
|
||||
require.Equal(t, 30*time.Second, c2.ttl)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_ETagDeterministic(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
payload := map[string]int{"a": 1, "b": 2}
|
||||
|
||||
entry1 := c.Set("k1", payload)
|
||||
entry2 := c.Set("k2", payload)
|
||||
require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag")
|
||||
}
|
||||
|
||||
func TestSnapshotCache_ETagFormat(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
entry := c.Set("k", "test")
|
||||
// ETag should be quoted hex string: "abcdef..."
|
||||
require.True(t, len(entry.ETag) > 2)
|
||||
require.Equal(t, byte('"'), entry.ETag[0])
|
||||
require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1])
|
||||
}
|
||||
|
||||
func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
||||
// channels are not JSON-serializable
|
||||
etag := buildETagFromAny(make(chan int))
|
||||
require.Empty(t, etag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
|
||||
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"hello": "world"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, hit)
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
|
||||
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"unexpected": "value"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, hit)
|
||||
require.Equal(t, entry.ETag, entry2.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
start := make(chan struct{})
|
||||
const callers = 8
|
||||
errCh := make(chan error, callers)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
for range callers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
_, _, err := c.GetOrLoad("shared", func() (any, error) {
|
||||
loads.Add(1)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return "value", nil
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
def bool
|
||||
want bool
|
||||
}{
|
||||
{"empty returns default true", "", true, true},
|
||||
{"empty returns default false", "", false, false},
|
||||
{"1", "1", false, true},
|
||||
{"true", "true", false, true},
|
||||
{"TRUE", "TRUE", false, true},
|
||||
{"yes", "yes", false, true},
|
||||
{"on", "on", false, true},
|
||||
{"0", "0", true, false},
|
||||
{"false", "false", true, false},
|
||||
{"FALSE", "FALSE", true, false},
|
||||
{"no", "no", true, false},
|
||||
{"off", "off", true, false},
|
||||
{"whitespace trimmed", " true ", false, true},
|
||||
{"unknown returns default true", "maybe", true, true},
|
||||
{"unknown returns default false", "maybe", false, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := parseBoolQueryWithDefault(tc.raw, tc.def)
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -216,6 +216,38 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||
type ResetSubscriptionQuotaRequest struct {
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
Monthly bool `json:"monthly"`
|
||||
}
|
||||
|
||||
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
|
||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
var req ResetSubscriptionQuotaRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Daily && !req.Weekly && !req.Monthly {
|
||||
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
|
||||
return
|
||||
}
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
// DELETE /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
|
||||
@@ -61,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct {
|
||||
// GET /api/v1/admin/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
exactTotal := false
|
||||
if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" {
|
||||
parsed, err := strconv.ParseBool(exactTotalRaw)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid exact_total value, use true or false")
|
||||
return
|
||||
}
|
||||
exactTotal = parsed
|
||||
}
|
||||
|
||||
// Parse filters
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
@@ -167,6 +176,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
ExactTotal: exactTotal,
|
||||
}
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
|
||||
@@ -80,6 +80,29 @@ func TestAdminUsageListInvalidStream(t *testing.T) {
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageListExactTotalTrue(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.True(t, repo.listFilters.ExactTotal)
|
||||
}
|
||||
|
||||
func TestAdminUsageListInvalidExactTotal(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct {
|
||||
Attributes map[int64]map[int64]string `json:"attributes"`
|
||||
}
|
||||
|
||||
var userAttributesBatchCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
// AttributeDefinitionResponse represents attribute definition response
|
||||
type AttributeDefinitionResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
@@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
userIDs := normalizeInt64IDList(req.UserIDs)
|
||||
if len(userIDs) == 0 {
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
|
||||
return
|
||||
}
|
||||
|
||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
}{
|
||||
UserIDs: userIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := userAttributesBatchCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
|
||||
payload := BatchUserAttributesResponse{Attributes: attrs}
|
||||
userAttributesBatchCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
@@ -91,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
Search: search,
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
if raw, ok := c.GetQuery("include_subscriptions"); ok {
|
||||
includeSubscriptions := parseBoolQueryWithDefault(raw, true)
|
||||
filters.IncludeSubscriptions = &includeSubscriptions
|
||||
}
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -73,7 +74,23 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
|
||||
// Parse filter parameters
|
||||
var filters service.APIKeyListFilters
|
||||
if search := strings.TrimSpace(c.Query("search")); search != "" {
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
filters.Search = search
|
||||
}
|
||||
filters.Status = c.Query("status")
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
gid, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err == nil {
|
||||
filters.GroupID = &gid
|
||||
}
|
||||
}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
// Get the user (before session deletion so we can check backend mode)
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session (only after all checks pass)
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
||||
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||
if frontendBaseURL == "" {
|
||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
||||
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||
response.InternalError(c, "Password reset is not configured")
|
||||
return
|
||||
}
|
||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: block non-admin token refresh
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
|
||||
if tokenErr != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
|
||||
return
|
||||
}
|
||||
fragment := url.Values{}
|
||||
fragment.Set("error", "invitation_required")
|
||||
fragment.Set("pending_oauth_token", pendingToken)
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
return
|
||||
}
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
}
|
||||
|
||||
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
|
||||
// the invitation code and creating the user account.
|
||||
// POST /api/v1/auth/oauth/linuxdo/complete-registration
|
||||
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
var req completeLinuxDoOAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
)
|
||||
|
||||
type Announcement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
NotifyMode string `json:"notify_mode"`
|
||||
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
|
||||
@@ -25,9 +26,10 @@ type Announcement struct {
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
NotifyMode string `json:"notify_mode"`
|
||||
|
||||
StartsAt *time.Time `json:"starts_at,omitempty"`
|
||||
EndsAt *time.Time `json:"ends_at,omitempty"`
|
||||
@@ -43,17 +45,18 @@ func AnnouncementFromService(a *service.Announcement) *Announcement {
|
||||
return nil
|
||||
}
|
||||
return &Announcement{
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
NotifyMode: a.NotifyMode,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,13 +65,14 @@ func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement
|
||||
return nil
|
||||
}
|
||||
return &UserAnnouncement{
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
NotifyMode: a.Announcement.NotifyMode,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
out := &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
@@ -89,15 +89,28 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
RateLimit5h: k.RateLimit5h,
|
||||
RateLimit1d: k.RateLimit1d,
|
||||
RateLimit7d: k.RateLimit7d,
|
||||
Usage5h: k.Usage5h,
|
||||
Usage1d: k.Usage1d,
|
||||
Usage7d: k.Usage7d,
|
||||
Usage5h: k.EffectiveUsage5h(),
|
||||
Usage1d: k.EffectiveUsage1d(),
|
||||
Usage7d: k.EffectiveUsage7d(),
|
||||
Window5hStart: k.Window5hStart,
|
||||
Window1dStart: k.Window1dStart,
|
||||
Window7dStart: k.Window7dStart,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) {
|
||||
t := k.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
out.Reset5hAt = &t
|
||||
}
|
||||
if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) {
|
||||
t := k.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
out.Reset1dAt = &t
|
||||
}
|
||||
if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) {
|
||||
t := k.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
out.Reset7dAt = &t
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
@@ -126,6 +139,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
@@ -164,6 +178,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
@@ -183,6 +198,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
LoadFactor: a.LoadFactor,
|
||||
Priority: a.Priority,
|
||||
RateMultiplier: a.BillingRateMultiplier(),
|
||||
Status: a.Status,
|
||||
@@ -248,6 +264,50 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
if a.IsAPIKeyOrBedrock() {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
used := a.GetQuotaUsed()
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
out.QuotaDailyLimit = &limit
|
||||
used := a.GetQuotaDailyUsed()
|
||||
out.QuotaDailyUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
out.QuotaWeeklyLimit = &limit
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||
out.QuotaDailyResetMode = &mode
|
||||
hour := a.GetQuotaDailyResetHour()
|
||||
out.QuotaDailyResetHour = &hour
|
||||
}
|
||||
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||
out.QuotaWeeklyResetMode = &mode
|
||||
day := a.GetQuotaWeeklyResetDay()
|
||||
out.QuotaWeeklyResetDay = &day
|
||||
hour := a.GetQuotaWeeklyResetHour()
|
||||
out.QuotaWeeklyResetHour = &hour
|
||||
}
|
||||
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
tz := a.GetQuotaResetTimezone()
|
||||
out.QuotaResetTimezone = &tz
|
||||
}
|
||||
if a.Extra != nil {
|
||||
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaDailyResetAt = &v
|
||||
}
|
||||
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaWeeklyResetAt = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -461,7 +521,10 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
InboundEndpoint: l.InboundEndpoint,
|
||||
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
|
||||
@@ -71,3 +71,41 @@ func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
inboundEndpoint := "/v1/chat/completions"
|
||||
upstreamEndpoint := "/v1/responses"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
InboundEndpoint: &inboundEndpoint,
|
||||
UpstreamEndpoint: &upstreamEndpoint,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, userDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -17,13 +17,15 @@ type CustomMenuItem struct {
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@@ -80,6 +82,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -88,28 +93,30 @@ type DefaultSubscriptionSetting struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
Version string `json:"version"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
||||
@@ -159,6 +166,26 @@ type StreamTimeoutSettings struct {
|
||||
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||
}
|
||||
|
||||
// RectifierSettings 请求整流器配置 DTO
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// BetaPolicyRule Beta 策略规则 DTO
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
}
|
||||
|
||||
// BetaPolicySettings Beta 策略配置 DTO
|
||||
type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -57,6 +57,9 @@ type APIKey struct {
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
Reset5hAt *time.Time `json:"reset_5h_at,omitempty"`
|
||||
Reset1dAt *time.Time `json:"reset_1d_at,omitempty"`
|
||||
Reset7dAt *time.Time `json:"reset_7d_at,omitempty"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
@@ -96,6 +99,9 @@ type Group struct {
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
|
||||
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -112,6 +118,9 @@ type AdminGroup struct {
|
||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
@@ -131,6 +140,7 @@ type Account struct {
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
@@ -185,6 +195,24 @@ type Account struct {
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"`
|
||||
QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"`
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
// 配额固定时间重置配置
|
||||
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -304,9 +332,15 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
maxSameAccountRetries = 3
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
|
||||
@@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
for i := 1; i <= maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, i, fs.SameAccountRetryCount[100])
|
||||
}
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
// 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
@@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
// 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
// 再次遇到账号 100,计数仍为 maxSameAccountRetries,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
@@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
}
|
||||
// 再次触发时才会执行 TempUnschedule + 切换
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
@@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
}
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
// 2. 账号 100 超过重试上限 → TempUnschedule + 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -434,19 +441,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -635,6 +648,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -652,6 +667,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
// Beta policy block: return 400 immediately, no failover
|
||||
var betaBlockedErr *service.BetaBlockedError
|
||||
if errors.As(err, &betaBlockedErr) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
|
||||
return
|
||||
}
|
||||
|
||||
var promptTooLongErr *service.PromptTooLongError
|
||||
if errors.As(err, &promptTooLongErr) {
|
||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||
@@ -697,6 +719,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -729,19 +756,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -971,34 +1004,46 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
if err == nil && rateLimitData != nil {
|
||||
var rateLimits []gin.H
|
||||
if apiKey.RateLimit5h > 0 {
|
||||
used := rateLimitData.Usage5h
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
used := rateLimitData.EffectiveUsage5h()
|
||||
entry := gin.H{
|
||||
"window": "5h",
|
||||
"limit": apiKey.RateLimit5h,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit5h-used),
|
||||
"window_start": rateLimitData.Window5hStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) {
|
||||
entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 {
|
||||
used := rateLimitData.Usage1d
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
used := rateLimitData.EffectiveUsage1d()
|
||||
entry := gin.H{
|
||||
"window": "1d",
|
||||
"limit": apiKey.RateLimit1d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit1d-used),
|
||||
"window_start": rateLimitData.Window1dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) {
|
||||
entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 {
|
||||
used := rateLimitData.Usage7d
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
used := rateLimitData.EffectiveUsage7d()
|
||||
entry := gin.H{
|
||||
"window": "7d",
|
||||
"limit": apiKey.RateLimit7d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit7d-used),
|
||||
"window_start": rateLimitData.Window7dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) {
|
||||
entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if len(rateLimits) > 0 {
|
||||
resp["rate_limits"] = rateLimits
|
||||
|
||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
|
||||
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||
|
||||
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||
// 具体验证:
|
||||
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||
|
||||
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||
|
||||
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||
}
|
||||
|
||||
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||
|
||||
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||
|
||||
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx,
|
||||
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
}
|
||||
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
require.Contains(t, body, "event: message_start")
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||
// c.Writer.Size() 仍为 -1
|
||||
|
||||
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||
require.False(t, guardTriggered,
|
||||
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||
}
|
||||
@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
@@ -138,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // usageBillingRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
@@ -155,6 +157,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // sessionLimitCache
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
|
||||
@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
|
||||
@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
|
||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
Backup *admin.BackupHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
@@ -27,6 +28,7 @@ type AdminHandlers struct {
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
286
backend/internal/handler/openai_chat_completions.go
Normal file
286
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||
// POST /v1/chat/completions
|
||||
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointChatCompletions),
|
||||
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_chat_completions.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var handlerStructuredLogCaptureMu sync.Mutex
|
||||
|
||||
type handlerInMemoryLogSink struct {
|
||||
mu sync.Mutex
|
||||
events []*logger.LogEvent
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
if event == nil {
|
||||
return
|
||||
}
|
||||
cloned := *event
|
||||
if event.Fields != nil {
|
||||
cloned.Fields = make(map[string]any, len(event.Fields))
|
||||
for k, v := range event.Fields {
|
||||
cloned.Fields[k] = v
|
||||
}
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.events = append(s.events, &cloned)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
wantLevel := strings.ToLower(strings.TrimSpace(level))
|
||||
for _, ev := range s.events {
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, ev := range s.events {
|
||||
if ev == nil || ev.Fields == nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) {
|
||||
t.Helper()
|
||||
handlerStructuredLogCaptureMu.Lock()
|
||||
|
||||
err := logger.Init(logger.InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: logger.SamplingOptions{Enabled: false},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sink := &handlerInMemoryLogSink{}
|
||||
logger.SetSink(sink)
|
||||
return sink, func() {
|
||||
logger.SetSink(nil)
|
||||
handlerStructuredLogCaptureMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOpenAIRemoteCompactPath(t *testing.T) {
|
||||
require.False(t, isOpenAIRemoteCompactPath(nil))
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
require.False(t, isOpenAIRemoteCompactPath(c))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Set(opsModelKey, "gpt-5.3-codex")
|
||||
c.Set(opsAccountIDKey, int64(123))
|
||||
c.Header("x-request-id", "rid-compact-ok")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond))
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "200"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex"))
|
||||
require.True(t, logSink.ContainsFieldValue("account_id", "123"))
|
||||
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Status(http.StatusBadGateway)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "502"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/responses/compact"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "401"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalizedOpenAIUpstreamEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
fallback string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "responses root maps to responses upstream",
|
||||
path: "/v1/responses",
|
||||
fallback: openAIUpstreamEndpointResponses,
|
||||
want: "/v1/responses",
|
||||
},
|
||||
{
|
||||
name: "responses compact keeps compact suffix",
|
||||
path: "/openai/v1/responses/compact",
|
||||
fallback: openAIUpstreamEndpointResponses,
|
||||
want: "/v1/responses/compact",
|
||||
},
|
||||
{
|
||||
name: "responses nested suffix preserved",
|
||||
path: "/openai/v1/responses/compact/detail",
|
||||
fallback: openAIUpstreamEndpointResponses,
|
||||
want: "/v1/responses/compact/detail",
|
||||
},
|
||||
{
|
||||
name: "non responses path uses fallback",
|
||||
path: "/v1/messages",
|
||||
fallback: openAIUpstreamEndpointResponses,
|
||||
want: "/v1/responses",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||
|
||||
got := normalizedOpenAIUpstreamEndpoint(c, tt.fallback)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -33,8 +34,16 @@ type OpenAIGatewayHandler struct {
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
const (
|
||||
openAIInboundEndpointResponses = "/v1/responses"
|
||||
openAIInboundEndpointMessages = "/v1/messages"
|
||||
openAIInboundEndpointChatCompletions = "/v1/chat/completions"
|
||||
openAIUpstreamEndpointResponses = "/v1/responses"
|
||||
)
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
@@ -61,6 +70,7 @@ func NewOpenAIGatewayHandler(
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +80,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
compactStartedAt := time.Now()
|
||||
defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt)
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
|
||||
requestStart := time.Now()
|
||||
@@ -114,6 +126,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
sessionHashBody := body
|
||||
if service.IsOpenAIResponsesCompactPathForTest(c) {
|
||||
if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" {
|
||||
c.Set(service.OpenAICompactSessionSeedKeyForTest(), compactSeed)
|
||||
}
|
||||
normalizedCompactBody, normalizedCompact, compactErr := service.NormalizeOpenAICompactRequestBodyForTest(body)
|
||||
if compactErr != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to normalize compact request body")
|
||||
return
|
||||
}
|
||||
if normalizedCompact {
|
||||
body = normalizedCompactBody
|
||||
}
|
||||
}
|
||||
|
||||
// 校验请求体 JSON 合法性
|
||||
if !gjson.ValidBytes(body) {
|
||||
@@ -189,11 +215,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -241,6 +268,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
@@ -270,6 +298,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -301,6 +348,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
@@ -309,18 +359,22 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
|
||||
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
@@ -340,6 +394,431 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAIRemoteCompactPath(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return false
|
||||
}
|
||||
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
||||
return strings.HasSuffix(normalizedPath, "/responses/compact")
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) {
|
||||
if !isOpenAIRemoteCompactPath(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
path string
|
||||
status int
|
||||
)
|
||||
if c != nil {
|
||||
if c.Request != nil {
|
||||
ctx = c.Request.Context()
|
||||
if c.Request.URL != nil {
|
||||
path = strings.TrimSpace(c.Request.URL.Path)
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
status = c.Writer.Status()
|
||||
}
|
||||
}
|
||||
|
||||
outcome := "failed"
|
||||
if status >= 200 && status < 300 {
|
||||
outcome = "succeeded"
|
||||
}
|
||||
latencyMs := time.Since(startedAt).Milliseconds()
|
||||
if latencyMs < 0 {
|
||||
latencyMs = 0
|
||||
}
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Bool("remote_compact", true),
|
||||
zap.String("compact_outcome", outcome),
|
||||
zap.Int("status_code", status),
|
||||
zap.Int64("latency_ms", latencyMs),
|
||||
zap.String("path", path),
|
||||
zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI),
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" {
|
||||
fields = append(fields, zap.String("request_user_agent", userAgent))
|
||||
}
|
||||
if v, ok := c.Get(opsModelKey); ok {
|
||||
if model, ok := v.(string); ok && strings.TrimSpace(model) != "" {
|
||||
fields = append(fields, zap.String("request_model", strings.TrimSpace(model)))
|
||||
}
|
||||
}
|
||||
if v, ok := c.Get(opsAccountIDKey); ok {
|
||||
if accountID, ok := v.(int64); ok && accountID > 0 {
|
||||
fields = append(fields, zap.Int64("account_id", accountID))
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
} else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log := logger.FromContext(ctx).With(fields...)
|
||||
if outcome == "succeeded" {
|
||||
log.Info("codex.remote_compact.succeeded")
|
||||
return
|
||||
}
|
||||
log.Warn("codex.remote_compact.failed")
|
||||
}
|
||||
|
||||
// Messages handles Anthropic Messages API requests routed to OpenAI platform.
|
||||
// POST /v1/messages (when group platform is OpenAI)
|
||||
func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverAnthropicMessagesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.anthropicErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.messages",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 检查分组是否允许 /v1/messages 调度
|
||||
if apiKey.Group != nil && !apiKey.Group.AllowMessagesDispatch {
|
||||
h.anthropicErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
"This group does not allow /v1/messages dispatch")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.anthropicErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
|
||||
// 而非 OpenAI 的 session_id/conversation_id headers。
|
||||
// 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。
|
||||
if sessionHash == "" || promptCacheKey == "" {
|
||||
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
|
||||
seed := reqModel + "-" + userID
|
||||
if promptCacheKey == "" {
|
||||
promptCacheKey = service.GenerateSessionUUID(seed)
|
||||
}
|
||||
if sessionHash == "" {
|
||||
sessionHash = service.DeriveSessionHashFromSeed(seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
|
||||
c.Set("openai_messages_fallback_model", "")
|
||||
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"", // no previous_response_id
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_messages.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_messages.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_messages_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleAnthropicFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointMessages),
|
||||
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_messages.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_messages.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// anthropicErrorResponse writes an error in Anthropic Messages API format.
|
||||
func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// anthropicStreamingAwareError handles errors that may occur during streaming,
|
||||
// using Anthropic SSE error format.
|
||||
func (h *OpenAIGatewayHandler) anthropicStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
errPayload, _ := json.Marshal(gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errPayload) //nolint:errcheck
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
h.anthropicErrorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// handleAnthropicFailoverExhausted maps upstream failover errors to Anthropic format.
|
||||
func (h *OpenAIGatewayHandler) handleAnthropicFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(failoverErr.StatusCode)
|
||||
h.anthropicStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
// ensureAnthropicErrorResponse writes a fallback Anthropic error if no response was written.
|
||||
func (h *OpenAIGatewayHandler) ensureAnthropicErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
@@ -756,17 +1235,23 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
if turnErr != nil || result == nil {
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
|
||||
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
@@ -817,6 +1302,26 @@ func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStart
|
||||
)
|
||||
}
|
||||
|
||||
// recoverAnthropicMessagesPanic recovers from panics in the Anthropic Messages
|
||||
// handler and returns an Anthropic-formatted error response.
|
||||
func (h *OpenAIGatewayHandler) recoverAnthropicMessagesPanic(c *gin.Context, streamStarted *bool) {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
return
|
||||
}
|
||||
|
||||
started := streamStarted != nil && *streamStarted
|
||||
requestLogger(c, "handler.openai_gateway.messages").Error(
|
||||
"openai.messages_panic_recovered",
|
||||
zap.Bool("stream_started", started),
|
||||
zap.Any("panic", recovered),
|
||||
zap.ByteString("stack", debug.Stack()),
|
||||
)
|
||||
if !started {
|
||||
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
|
||||
missing := h.missingResponsesDependencies()
|
||||
if len(missing) == 0 {
|
||||
@@ -1022,6 +1527,14 @@ func setOpenAIClientTransportWS(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||
}
|
||||
|
||||
func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string {
|
||||
if sessionHash != "" || account == nil || !account.IsPoolMode() {
|
||||
return sessionHash
|
||||
}
|
||||
// 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。
|
||||
return "openai-pool-retry-" + uuid.NewString()
|
||||
}
|
||||
|
||||
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||
gid := int64(0)
|
||||
if groupID != nil {
|
||||
@@ -1030,6 +1543,62 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
|
||||
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
|
||||
}
|
||||
|
||||
func normalizedOpenAIInboundEndpoint(c *gin.Context, fallback string) string {
|
||||
path := strings.TrimSpace(fallback)
|
||||
if c != nil {
|
||||
if fullPath := strings.TrimSpace(c.FullPath()); fullPath != "" {
|
||||
path = fullPath
|
||||
} else if c.Request != nil && c.Request.URL != nil {
|
||||
if requestPath := strings.TrimSpace(c.Request.URL.Path); requestPath != "" {
|
||||
path = requestPath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.Contains(path, openAIInboundEndpointChatCompletions):
|
||||
return openAIInboundEndpointChatCompletions
|
||||
case strings.Contains(path, openAIInboundEndpointMessages):
|
||||
return openAIInboundEndpointMessages
|
||||
case strings.Contains(path, openAIInboundEndpointResponses):
|
||||
return openAIInboundEndpointResponses
|
||||
default:
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
func normalizedOpenAIUpstreamEndpoint(c *gin.Context, fallback string) string {
|
||||
base := strings.TrimSpace(fallback)
|
||||
if base == "" {
|
||||
base = openAIUpstreamEndpointResponses
|
||||
}
|
||||
base = strings.TrimRight(base, "/")
|
||||
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return base
|
||||
}
|
||||
|
||||
path := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
||||
if path == "" {
|
||||
return base
|
||||
}
|
||||
|
||||
idx := strings.LastIndex(path, "/responses")
|
||||
if idx < 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
suffix := strings.TrimSpace(path[idx+len("/responses"):])
|
||||
if suffix == "" || suffix == "/" {
|
||||
return base
|
||||
}
|
||||
if !strings.HasPrefix(suffix, "/") {
|
||||
return base
|
||||
}
|
||||
|
||||
return base + suffix
|
||||
}
|
||||
|
||||
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
|
||||
@@ -26,11 +26,28 @@ const (
|
||||
opsStreamKey = "ops_stream"
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
|
||||
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||
opsErrContextCanceled = "context canceled"
|
||||
opsErrNoAvailableAccounts = "no available accounts"
|
||||
opsErrInvalidAPIKey = "invalid_api_key"
|
||||
opsErrAPIKeyRequired = "api_key_required"
|
||||
opsErrInsufficientBalance = "insufficient balance"
|
||||
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||
opsErrInsufficientQuota = "insufficient_quota"
|
||||
|
||||
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
)
|
||||
|
||||
const (
|
||||
opsErrorLogTimeout = 5 * time.Second
|
||||
opsErrorLogDrainTimeout = 10 * time.Second
|
||||
opsErrorLogBatchWindow = 200 * time.Millisecond
|
||||
|
||||
opsErrorLogMinWorkerCount = 4
|
||||
opsErrorLogMaxWorkerCount = 32
|
||||
@@ -38,6 +55,7 @@ const (
|
||||
opsErrorLogQueueSizePerWorker = 128
|
||||
opsErrorLogMinQueueSize = 256
|
||||
opsErrorLogMaxQueueSize = 8192
|
||||
opsErrorLogBatchSize = 32
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
@@ -82,27 +100,82 @@ func startOpsErrorLogWorkers() {
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go func() {
|
||||
defer opsErrorLogWorkersWg.Done()
|
||||
for job := range opsErrorLogQueue {
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
for {
|
||||
job, ok := <-opsErrorLogQueue
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
|
||||
batch = append(batch, job)
|
||||
|
||||
timer := time.NewTimer(opsErrorLogBatchWindow)
|
||||
batchLoop:
|
||||
for len(batch) < opsErrorLogBatchSize {
|
||||
select {
|
||||
case nextJob, ok := <-opsErrorLogQueue:
|
||||
if !ok {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
return
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||
cancel()
|
||||
opsErrorLogProcessed.Add(1)
|
||||
}()
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch = append(batch, nextJob)
|
||||
case <-timer.C:
|
||||
break batchLoop
|
||||
}
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
|
||||
var processed int64
|
||||
for _, job := range batch {
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
}
|
||||
grouped[job.ops] = append(grouped[job.ops], job.entry)
|
||||
processed++
|
||||
}
|
||||
if processed == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for opsSvc, entries := range grouped {
|
||||
if opsSvc == nil || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = opsSvc.RecordErrorBatch(ctx, entries)
|
||||
cancel()
|
||||
}
|
||||
opsErrorLogProcessed.Add(processed)
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
@@ -967,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
||||
return errType
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE":
|
||||
case opsCodeInsufficientBalance:
|
||||
return "billing_error"
|
||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "subscription_error"
|
||||
default:
|
||||
return "api_error"
|
||||
@@ -981,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||
// Map billing/concurrency/response => request; scheduling => routing.
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "request"
|
||||
}
|
||||
|
||||
@@ -1000,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, "no available accounts") {
|
||||
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||
return "routing"
|
||||
}
|
||||
return "internal"
|
||||
@@ -1046,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||
return true
|
||||
}
|
||||
if phase == "billing" || phase == "concurrency" {
|
||||
@@ -1140,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
||||
|
||||
// Check if context canceled errors should be ignored (client disconnects)
|
||||
if settings.IgnoreContextCanceled {
|
||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
||||
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if "no available accounts" errors should be ignored
|
||||
if settings.IgnoreNoAvailableAccounts {
|
||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
||||
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||
if settings.IgnoreInvalidApiKeyErrors {
|
||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
||||
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if insufficient balance errors should be ignored
|
||||
if settings.IgnoreInsufficientBalanceErrors {
|
||||
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,27 +32,29 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
Version: h.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -996,7 +996,7 @@ func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*se
|
||||
}
|
||||
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
|
||||
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
|
||||
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||
@@ -2132,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
|
||||
|
||||
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
|
||||
@@ -2198,8 +2206,8 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
|
||||
@@ -216,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
@@ -326,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -335,6 +351,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -423,6 +442,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
testutil.StubGatewayCache{},
|
||||
cfg,
|
||||
nil,
|
||||
@@ -437,6 +457,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
|
||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
backupHandler *admin.BackupHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -30,6 +31,7 @@ func ProvideAdminHandlers(
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||
scheduledTestHandler *admin.ScheduledTestHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -38,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
Backup: backupHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -53,6 +56,7 @@ func ProvideAdminHandlers(
|
||||
UserAttribute: userAttributeHandler,
|
||||
ErrorPassthrough: errorPassthroughHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
ScheduledTest: scheduledTestHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewBackupHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
@@ -141,6 +146,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewUserAttributeHandler,
|
||||
admin.NewErrorPassthroughHandler,
|
||||
admin.NewAdminAPIKeyHandler,
|
||||
admin.NewScheduledTestHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
var geminiModels = []modelDef{
|
||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
|
||||
@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||
|
||||
requiredIDs := []string{
|
||||
"claude-opus-4-6-thinking",
|
||||
"gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview",
|
||||
"gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview",
|
||||
"gemini-3-pro-image", // legacy compatibility
|
||||
|
||||
@@ -19,6 +19,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// ForbiddenError 表示上游返回 403 Forbidden
|
||||
type ForbiddenError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *ForbiddenError) Error() string {
|
||||
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -514,7 +524,20 @@ type ModelQuotaInfo struct {
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||
}
|
||||
|
||||
// DeprecatedModelInfo 废弃模型转发信息
|
||||
type DeprecatedModelInfo struct {
|
||||
NewModelID string `json:"newModelId"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
@@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct {
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
@@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, nil, &ForbiddenError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBodyBytes),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user