Compare commits

..

73 Commits

Author SHA1 Message Date
shaw
0cce0a8877 chore: gpt-5.4示例配置修改model_reasoning_effort为xhigh 2026-03-06 11:29:43 +08:00
shaw
225fd035ae chore: 更新codex配置部分支持gpt-5.4的长上下文 2026-03-06 10:55:09 +08:00
Wesley Liddick
fb7d1346b5 Merge pull request #800 from mt21625457/pr/gpt54-support-upstream
feat(openai): 增加 GPT-5.4 支持并修复长上下文计费与白名单回归
2026-03-06 10:42:01 +08:00
shaw
491a744481 fix: 修复账号列表首次加载窗口费用显示 $0.00
lite 模式下从快照缓存读取窗口费用,非 lite 模式查询后写入缓存
2026-03-06 10:23:22 +08:00
yangjianbo
f366026435 fix(openai): 修复 gpt-5.4 长上下文计费与快照白名单
补齐 gpt-5.4 fallback 的长上下文计费元信息,\n确保超过 272000 输入 token 时对整次会话应用\n2x 输入与 1.5x 输出计费规则。\n\n同时将官方快照 gpt-5.4-2026-03-05 加入前端\n白名单候选与回归测试,避免 whitelist 模式误拦截。\n\nCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

(cherry picked from commit d95497af87f608c6dadcbe7d6e851de9413ae147)
2026-03-06 10:16:23 +08:00
yangjianbo
1a0d4ed668 feat(openai): 增加 gpt-5.4 模型支持与定价配置
- 接入 gpt-5.4 模型识别与规范化,补充默认模型列表
- 增加 gpt-5.4 输入/缓存命中/输出价格与计费兜底逻辑
- 同步前端模型白名单与 OpenCode 上下文窗口(1050000/128000)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
(cherry picked from commit 924476dcac6181cd0f3ee731ec7b73672ff03793)
2026-03-06 10:16:23 +08:00
Wesley Liddick
63a8c76946 Merge pull request #798 from touwaeriol/feature/account-load-factor
feat: add account load_factor for scheduling load calculation
2026-03-06 09:42:10 +08:00
Wesley Liddick
f355a68bc9 Merge pull request #796 from touwaeriol/feature/apikey-quota-limit
feat: add configurable spending limit for API Key accounts
2026-03-06 09:37:52 +08:00
erio
c87e6526c1 fix: use real Concurrency instead of LoadFactor for metrics and monitoring
LoadFactor should only affect scheduling weight, not load rate reporting.
2026-03-06 05:18:45 +08:00
erio
af3a5076d6 fix: add load_factor upper bound validation to BulkUpdateAccounts 2026-03-06 05:17:52 +08:00
erio
18f2e21414 fix: use HTML-safe expressions for @input handlers in Vue templates
Replace `<` comparisons with Math.max/ternary+>= to avoid Vue template
parser treating `<` as HTML tag start in attribute values.
2026-03-06 05:07:52 +08:00
erio
8a8cdeebb4 fix: prevent negative values for concurrency and load_factor inputs 2026-03-06 05:07:52 +08:00
erio
12b33f4ea4 fix: address load_factor code review findings
- Fix bulk edit: send 0 instead of null/NaN to clear load_factor
- Fix edit modal: explicit NaN check instead of implicit falsy
- Fix create modal: use ?? instead of || for load_factor
- Add load_factor upper limit validation (max 10000)
- Add //go:build unit tag and self-contained intPtrHelper in test
- Add design intent comments on WaitPlan.MaxConcurrency
2026-03-06 05:07:52 +08:00
erio
01b3a09d7d fix: validate account status before update and update load factor hint
- Normalize non-standard status (e.g. "error") to "active" on edit load
- Add pre-submit validation for status field to prevent 400 errors
- Update load factor hint: "提高负载因子可以提高对账号的调度频率"
2026-03-06 05:07:40 +08:00
erio
0d6c1c7790 feat: add independent load_factor field for scheduling load calculation 2026-03-06 05:07:10 +08:00
erio
95e366b6c6 fix: add missing IncrementQuotaUsed and ResetQuotaUsed to stubAccountRepo in api_contract_test 2026-03-06 04:37:56 +08:00
erio
77701143bf fix: use range assertion for time-sensitive ExpiresInDays test
The test could flake depending on exact execution time near midnight
boundaries. Use a range check (29 or 30) instead of exact equality.
2026-03-06 01:07:28 +08:00
erio
02dea7b09b refactor: unify post-usage billing logic and fix account quota calculation
- Extract postUsageBilling() to consolidate billing logic across
  GatewayService.RecordUsage, RecordUsageWithLongContext, and
  OpenAIGatewayService.RecordUsage, eliminating ~120 lines of
  duplicated code
- Fix account quota to use TotalCost × accountRateMultiplier
  (was using raw TotalCost, inconsistent with account cost stats)
- Fix RecordUsageWithLongContext API Key quota only updating in
  balance mode (now updates regardless of billing type)
- Fix WebSocket client disconnect detection on Windows by adding
  "an established connection was aborted" to known disconnect errors
2026-03-06 00:54:17 +08:00
erio
c26f93c4a0 fix: route antigravity apikey account test to native protocol
Antigravity APIKey accounts were incorrectly routed to
testAntigravityAccountConnection which calls AntigravityTokenProvider,
but the token provider only handles OAuth and Upstream types, causing
"not an antigravity oauth account" error.

Extract routeAntigravityTest to route APIKey accounts to native
Claude/Gemini test paths based on model prefix, matching the
gateway_handler routing logic for normal requests.
2026-03-06 00:36:13 +08:00
erio
c826ac28ef refactor: extract QuotaLimitCard component for reuse in create and edit modals
- Extract quota limit card/toggle UI into QuotaLimitCard.vue component
- Use v-model pattern for clean parent-child data flow
- Integrate into both EditAccountModal and CreateAccountModal
- All apikey accounts (all platforms) now support quota limit on creation
- Bump version to 0.1.90.6
2026-03-06 00:36:00 +08:00
erio
1893b0eb30 feat: restyle API Key quota limit UI to card/toggle format
- Redesign quota limit section with card layout and toggle switch
- Add watch to clear quota value when toggle is disabled
- Add i18n keys for toggle labels and hints (zh/en)
- Bump version to 0.1.90.5
2026-03-06 00:35:35 +08:00
erio
05527b13db feat: add quota limit for API key accounts
- Add configurable spending limit (quota_limit) for apikey-type accounts
- Atomic quota accumulation via PostgreSQL JSONB operations on TotalCost
- Scheduler filters out over-quota accounts with outbox-triggered snapshot refresh
- Display quota usage ($used / $limit) in account capacity column
- Add "Reset Quota" action in account menu to reset usage to zero
- Editing account settings preserves quota_used (no accidental reset)
- Covers all 3 billing paths: Anthropic, Gemini, OpenAI RecordUsage

chore: bump version to 0.1.90.4
2026-03-06 00:35:09 +08:00
Wesley Liddick
ae5d9c8bfc Merge pull request #788 from touwaeriol/fix/usage-error-passthrough
fix: pass through upstream HTTP status in usage API errors
2026-03-05 22:05:36 +08:00
erio
9117c2a4ec fix: include upstream error details in usage API error response
When FetchUsageWithOptions receives a non-200 response from the
Anthropic API (e.g. 429 Rate Limited, 401 Unauthorized), the error
was wrapped with fmt.Errorf which infraerrors.FromError cannot
recognize, causing a generic "internal error" message with no details.

Replace fmt.Errorf with infraerrors.New(500, "UPSTREAM_ERROR", msg)
so the upstream error details (status code + body) are included in
the 500 response message. The HTTP status remains 500 to avoid
interfering with frontend auth routing (e.g. 401 would trigger
JWT expiry redirect).
2026-03-05 21:02:11 +08:00
shaw
bab4bb9904 chore: 更新openai、claude使用秘钥教程部分 2026-03-05 18:58:10 +08:00
shaw
33bae6f49b fix: Cache Token拆分为缓存创建和缓存读取 2026-03-05 18:32:17 +08:00
Wesley Liddick
32d619a56b Merge pull request #780 from mt21625457/feat/codex-remote-compact-outcome-logging
feat(openai-handler): support codex remote compact outcome logging
2026-03-05 16:59:02 +08:00
Wesley Liddick
642432cf2a Merge pull request #777 from guoyongchang/feature-schedule-test-support
feat: 支持基于 crontab 的定时账号测试
2026-03-05 16:57:23 +08:00
程序猿MT
61e9598b08 fix(lint): remove redundant context type in compact outcome logger 2026-03-05 16:51:46 +08:00
guoyongchang
d4e34c7514 fix: 修复空结果导致定时测试模态框崩溃的问题
后端返回 null (Go nil slice) 时前端访问 .length 抛出 TypeError,
在 API 层对 listByAccount 和 listResults 加 ?? [] 兜底。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:47:01 +08:00
程序猿MT
bfe7a5e452 test(openai-handler): add codex remote compact outcome coverage 2026-03-05 16:46:14 +08:00
程序猿MT
77d916ffec feat(openai-handler): support codex remote compact outcome logging 2026-03-05 16:46:12 +08:00
guoyongchang
831abf7977 refactor: 移除冗余中间类型和不必要代码
- 移除 ScheduledTestOutcome 中间类型,RunTestBackground 直接返回 *ScheduledTestResult
- 简化 SaveResult 直接接受 *ScheduledTestResult
- 移除 handler 中不必要的 nil 检查
- 移除前端 ScheduledTestsPanel 中多余的 String() 转换

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:37:07 +08:00
guoyongchang
817a491087 simplify: 移除 leader lock,单实例无需分布式锁
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:31:27 +08:00
guoyongchang
9a8dacc514 fix: 修复 golangci-lint depguard 和 gofmt 错误
将 redis leader lock 逻辑从 service 层抽取为 LeaderLocker 接口,
实现移至 repository 层,消除 service 层对 redis 的直接依赖。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:28:48 +08:00
guoyongchang
8adf80d98b fix: wire_gen_test 补充 scheduledTestRunner 参数
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:23:41 +08:00
guoyongchang
62686a6213 revert: 还原 docker-compose.local.yml 的本地测试改动
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:17:33 +08:00
guoyongchang
3a089242f8 feat: 支持基于 crontab 的定时账号测试
每个测试计划绑定一个账号和一个模型,按 cron 表达式定期执行测试,
保存历史结果并在前端账号管理页面中提供完整的增删改查和结果查看功能。

主要变更:
- 新增 scheduled_test_plans / scheduled_test_results 两张表及迁移
- 后端 service 层:CRUD 服务 + 后台 cron runner(每分钟扫描到期计划并发执行)
- RunTestBackground 方法通过 httptest 在内存中执行账号测试并解析 SSE 输出
- Redis leader lock + pg_try_advisory_lock 双重保障多实例部署只执行一次
- REST API:5 个管理端点(计划 CRUD + 结果查询)
- 前端 ScheduledTestsPanel 组件:计划管理、启用开关、内联编辑、结果展开查看
- 中英文 i18n 支持

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:06:05 +08:00
shaw
9d70c38504 fix: 修复claude apikey账号请求时未携带beta=true 查询参数的bug 2026-03-05 15:01:04 +08:00
shaw
aeb464f3ca feat: 模型映射应用 /v1/messages/count_tokens端点 2026-03-05 14:49:28 +08:00
Wesley Liddick
7076717b20 Merge pull request #772 from mt21625457/aicodex2api-main
feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode
2026-03-05 13:46:02 +08:00
程序猿MT
c0a4fcea0a Delete docker-compose-aicodex.yml
删除测试 docker compose文件
2026-03-05 13:44:07 +08:00
程序猿MT
aa2b195c86 Delete Caddyfile.dmit
删除测试caddy 配置文件
2026-03-05 13:43:25 +08:00
yangjianbo
1d0872e7ca feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode
新增 OpenAI WebSocket v2 passthrough relay 数据面与服务适配层,
支持按账号 ws mode 在 ctx_pool 与 passthrough 间路由。

同步调整前端 OpenAI ws mode 选项为 off/ctx_pool/passthrough,
并补充 i18n 文案与对应单测。

新增 Caddyfile.dmit 与 docker-compose-aicodex.yml 部署配置,
用于宿主机场景下的反向代理与服务编排。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 11:50:58 +08:00
shaw
33988637b5 fix: SMTP测试连接和发送测试邮件返回具体错误信息而非internal error 2026-03-05 10:54:41 +08:00
shaw
d4f6ad7225 feat: 新增apikey的usage查询页面 2026-03-05 10:45:51 +08:00
shaw
078fefed03 fix: 修复账号管理页面容量列显示为0的bug 2026-03-05 09:48:00 +08:00
Wesley Liddick
5b10af85b4 Merge pull request #762 from touwaeriol/fix/dark-theme-open-in-new-tab
fix: add dark theme support for "open in new tab" FAB button
2026-03-05 08:56:28 +08:00
Wesley Liddick
4caf95e5dd Merge pull request #767 from litianc/fix/rewrite-userid-regex-match-account-uuid
fix: extend RewriteUserID regex to match user_id containing account_uuid
2026-03-05 08:56:03 +08:00
litianc
8e1bcf53bb fix: extend RewriteUserID regex to match user_id containing account_uuid
The existing regex only matched the old format where account_uuid is
empty (account__session_). Real Claude Code clients and newer sub2api
generated user_ids use account_{uuid}_session_ which was silently
skipped, causing the original metadata.user_id to leak to upstream
when User-Agent is rewritten by an intermediate gateway.

Closes #766

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-04 23:13:17 +08:00
erio
064f9be7e4 fix: add dark theme support for "open in new tab" FAB button
The backdrop-blur background on the iframe "open in new tab" floating
button was hardcoded to bg-white/80, making it look broken in dark
theme. Added dark:bg-dark-800/80 variant for both PurchaseSubscription
and CustomPage views.
2026-03-04 21:40:40 +08:00
Wesley Liddick
adcfb44cb7 Merge pull request #761 from james-6-23/main
feat: 修复 v0.1.89 OAuth 401 永久锁死账号问题,改用临时不可调度实现自动恢复;增强二次 401 自动升级为错误状态,添加 DB   回退确保生效;管理后台新增临时不可调度状态筛选
2026-03-04 21:11:24 +08:00
kyx236
3d79773ba2 Merge branch 'main' of https://github.com/james-6-23/sub2api 2026-03-04 20:25:39 +08:00
kyx236
6aa8cbbf20 feat: 二次 401 直接升级为错误状态,添加 DB 回退确保生效
账号首次 401 仅临时不可调度,给予 token 刷新窗口;若恢复后再次 401
说明凭证确实失效,直接升级为错误状态以避免反复无效调度。

- 缓存中 reason 为空时从 DB 回退读取,防止升级判断失效
- ClearError 同时清除临时不可调度状态,管理员恢复后重新给予一次机会
- 管理后台账号列表添加"临时不可调度"状态筛选
- 补充 DB 回退场景单元测试
2026-03-04 20:25:15 +08:00
shaw
742e73c9c2 fix: 优化充值/订阅菜单的icon 2026-03-04 17:24:09 +08:00
shaw
f8de2bdedc fix(frontend): settings页面分tab拆分 2026-03-04 16:59:57 +08:00
shaw
59879b7fa7 fix(i18n): replace hardcoded English strings in EmailVerifyView with i18n calls 2026-03-04 15:58:44 +08:00
Wesley Liddick
27abae21b8 Merge pull request #724 from PMExtra/feat/registration-email-domain-whitelist
feat(registration): add email domain whitelist policy
2026-03-04 15:51:51 +08:00
shaw
0819c8a51a refactor: 消除重复的 normalizeAccountIDList,补充 PR#754 新增组件的单元测试
- 删除 account_today_stats_cache.go 中重复的 normalizeAccountIDList,统一使用 id_list_utils.go 的 normalizeInt64IDList
- 新增 snapshot_cache_test.go:覆盖 snapshotCache、buildETagFromAny、parseBoolQueryWithDefault
- 新增 id_list_utils_test.go:覆盖 normalizeInt64IDList、buildAccountTodayStatsBatchCacheKey
- 新增 ops_query_mode_test.go:覆盖 shouldFallbackOpsPreagg、cloneOpsFilterWithMode
2026-03-04 15:22:46 +08:00
Wesley Liddick
9dcd3cd491 Merge pull request #754 from xvhuan/perf/admin-core-large-dataset
perf(admin): 优化后台大数据场景加载性能(仪表盘/用户/账号/Ops)
2026-03-04 15:15:13 +08:00
Wesley Liddick
49767cccd2 Merge pull request #755 from xvhuan/perf/admin-usage-fast-pagination-main
perf(admin-usage): 优化 usage 大表分页,默认避免全量 COUNT(*)
2026-03-04 14:15:57 +08:00
PMExtra
29fb447daa fix(frontend): remove unused variables 2026-03-04 14:12:08 +08:00
xvhuan
f6fe5b552d fix(admin): resolve CI lint and user subscriptions regression 2026-03-04 14:07:17 +08:00
PMExtra
bd0801a887 feat(registration): add email domain whitelist policy 2026-03-04 13:54:18 +08:00
xvhuan
05b1c66aa8 perf(admin-usage): avoid expensive count on large usage_logs pagination 2026-03-04 13:51:27 +08:00
xvhuan
80ae592c23 perf(admin): optimize large-dataset loading for dashboard/users/accounts/ops 2026-03-04 13:45:49 +08:00
shaw
ba6de4c4d4 feat: /keys页面支持表单筛选 2026-03-04 11:29:31 +08:00
shaw
46ea9170cb fix: 修复自定义菜单页面管理员视角菜单不生效问题 2026-03-04 10:44:28 +08:00
shaw
7d318aeefa fix: 恢复check_pnpm_audit_exceptions.py 2026-03-04 10:20:19 +08:00
shaw
0aa3cf677a chore: 清理一些无用的文件 2026-03-04 10:15:42 +08:00
shaw
72961c5858 fix: Anthropic 平台无限流重置时间的 429 不再误标记账号限流 2026-03-04 09:36:24 +08:00
Wesley Liddick
a05711a37a Merge pull request #742 from zqq-nuli/fix/ops-error-detail-upstream-payload
fix(frontend): show real upstream payload in ops error detail modal
2026-03-04 09:04:11 +08:00
zqq61
efc9e1d673 fix(frontend): prefer upstream payload for generic ops error body 2026-03-03 23:45:34 +08:00
206 changed files with 11566 additions and 5153 deletions

105
AGENTS.md
View File

@@ -1,105 +0,0 @@
# Repository Guidelines
## Project Structure & Module Organization
- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output.
- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`.
- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`).
- `openspec/`: Spec-driven change docs (`changes/<id>/{proposal,design,tasks}.md`).
- `tools/`: Utility scripts (security/perf checks).
## Build, Test, and Development Commands
```bash
make build # Build backend + frontend
make test # Backend tests + frontend lint/typecheck
cd backend && make build # Build backend binary
cd backend && make test-unit # Go unit tests
cd backend && make test-integration # Go integration tests
cd backend && make test # go test ./... + golangci-lint
cd frontend && pnpm install --frozen-lockfile
cd frontend && pnpm dev # Vite dev server
cd frontend && pnpm build # Type-check + production build
cd frontend && pnpm test:run # Vitest run
cd frontend && pnpm test:coverage # Vitest + coverage report
python3 tools/secret_scan.py # Secret scan
```
## Coding Style & Naming Conventions
- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`).
- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard).
- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`.
- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`.
## Go & Frontend Development Standards
- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears.
- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency.
- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable.
### Go Performance Rules
- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code.
- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable.
- Every external I/O path must use `context.Context` with explicit timeout/cancel.
- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources.
- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`).
- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`.
- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths.
- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after).
- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out.
- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention.
- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible.
- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary.
- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions.
- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain.
- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible.
- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled.
### Data Access & Cache Rules
- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR.
- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints.
- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths.
- Keep transactions short; never perform external RPC/network calls inside DB transactions.
- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`).
- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks.
- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd.
- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths.
### Frontend Performance Rules
- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components.
- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs.
- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it.
- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed.
- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows).
- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking.
- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review.
- Avoid expensive inline computations in templates; move to cached `computed` selectors.
- Keep state normalized; avoid duplicated derived state across multiple stores/components.
- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import.
- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`.
- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks.
- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes.
### Performance Budget & PR Evidence
- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline.
- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval.
- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table).
- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output).
### Quality Gate
- Any changed code must include new or updated unit tests.
- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules).
- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description.
## Testing Guidelines
- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant.
- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`.
- Enforce unit-test and coverage rules defined in `Quality Gate`.
- Before opening a PR, run `make test` plus targeted tests for touched areas.
## Commit & Pull Request Guidelines
- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`.
- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes.
- For behavior/API changes, add or update `openspec/changes/...` artifacts.
- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR.
## Security & Configuration Tips
- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials.
- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev.

View File

@@ -137,8 +137,6 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。
#### 前置条件
- Docker 20.10+

View File

@@ -86,6 +86,7 @@ func provideCleanup(
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -216,6 +217,12 @@ func provideCleanup(
}
return nil
}},
{"ScheduledTestRunnerService", func() error {
if scheduledTestRunner != nil {
scheduledTestRunner.Stop()
}
return nil
}},
}
infraSteps := []cleanupStep{

View File

@@ -195,7 +195,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -225,7 +229,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -273,6 +278,7 @@ func provideCleanup(
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -402,6 +408,12 @@ func provideCleanup(
}
return nil
}},
{"ScheduledTestRunnerService", func() error {
if scheduledTestRunner != nil {
scheduledTestRunner.Stop()
}
return nil
}},
}
infraSteps := []cleanupStep{

View File

@@ -74,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
geminiOAuthSvc,
antigravityOAuthSvc,
nil, // openAIGateway
nil, // scheduledTestRunner
)
require.NotPanics(t, func() {

View File

@@ -41,6 +41,8 @@ type Account struct {
ProxyID *int64 `json:"proxy_id,omitempty"`
// Concurrency holds the value of the "concurrency" field.
Concurrency int `json:"concurrency,omitempty"`
// LoadFactor holds the value of the "load_factor" field.
LoadFactor *int `json:"load_factor,omitempty"`
// Priority holds the value of the "priority" field.
Priority int `json:"priority,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
@@ -143,7 +145,7 @@ func (*Account) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case account.FieldRateMultiplier:
values[i] = new(sql.NullFloat64)
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldLoadFactor, account.FieldPriority:
values[i] = new(sql.NullInt64)
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus:
values[i] = new(sql.NullString)
@@ -243,6 +245,13 @@ func (_m *Account) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Concurrency = int(value.Int64)
}
case account.FieldLoadFactor:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field load_factor", values[i])
} else if value.Valid {
_m.LoadFactor = new(int)
*_m.LoadFactor = int(value.Int64)
}
case account.FieldPriority:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field priority", values[i])
@@ -445,6 +454,11 @@ func (_m *Account) String() string {
builder.WriteString("concurrency=")
builder.WriteString(fmt.Sprintf("%v", _m.Concurrency))
builder.WriteString(", ")
if v := _m.LoadFactor; v != nil {
builder.WriteString("load_factor=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("priority=")
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
builder.WriteString(", ")

View File

@@ -37,6 +37,8 @@ const (
FieldProxyID = "proxy_id"
// FieldConcurrency holds the string denoting the concurrency field in the database.
FieldConcurrency = "concurrency"
// FieldLoadFactor holds the string denoting the load_factor field in the database.
FieldLoadFactor = "load_factor"
// FieldPriority holds the string denoting the priority field in the database.
FieldPriority = "priority"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
@@ -121,6 +123,7 @@ var Columns = []string{
FieldExtra,
FieldProxyID,
FieldConcurrency,
FieldLoadFactor,
FieldPriority,
FieldRateMultiplier,
FieldStatus,
@@ -250,6 +253,11 @@ func ByConcurrency(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldConcurrency, opts...).ToFunc()
}
// ByLoadFactor orders the results by the load_factor field.
func ByLoadFactor(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLoadFactor, opts...).ToFunc()
}
// ByPriority orders the results by the priority field.
func ByPriority(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPriority, opts...).ToFunc()

View File

@@ -100,6 +100,11 @@ func Concurrency(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldConcurrency, v))
}
// LoadFactor applies equality check predicate on the "load_factor" field. It's identical to LoadFactorEQ.
func LoadFactor(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
}
// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ.
func Priority(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPriority, v))
@@ -650,6 +655,56 @@ func ConcurrencyLTE(v int) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldConcurrency, v))
}
// LoadFactorEQ applies the EQ predicate on the "load_factor" field.
func LoadFactorEQ(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
}
// LoadFactorNEQ applies the NEQ predicate on the "load_factor" field.
func LoadFactorNEQ(v int) predicate.Account {
return predicate.Account(sql.FieldNEQ(FieldLoadFactor, v))
}
// LoadFactorIn applies the In predicate on the "load_factor" field.
func LoadFactorIn(vs ...int) predicate.Account {
return predicate.Account(sql.FieldIn(FieldLoadFactor, vs...))
}
// LoadFactorNotIn applies the NotIn predicate on the "load_factor" field.
func LoadFactorNotIn(vs ...int) predicate.Account {
return predicate.Account(sql.FieldNotIn(FieldLoadFactor, vs...))
}
// LoadFactorGT applies the GT predicate on the "load_factor" field.
func LoadFactorGT(v int) predicate.Account {
return predicate.Account(sql.FieldGT(FieldLoadFactor, v))
}
// LoadFactorGTE applies the GTE predicate on the "load_factor" field.
func LoadFactorGTE(v int) predicate.Account {
return predicate.Account(sql.FieldGTE(FieldLoadFactor, v))
}
// LoadFactorLT applies the LT predicate on the "load_factor" field.
func LoadFactorLT(v int) predicate.Account {
return predicate.Account(sql.FieldLT(FieldLoadFactor, v))
}
// LoadFactorLTE applies the LTE predicate on the "load_factor" field.
func LoadFactorLTE(v int) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldLoadFactor, v))
}
// LoadFactorIsNil applies the IsNil predicate on the "load_factor" field.
func LoadFactorIsNil() predicate.Account {
return predicate.Account(sql.FieldIsNull(FieldLoadFactor))
}
// LoadFactorNotNil applies the NotNil predicate on the "load_factor" field.
func LoadFactorNotNil() predicate.Account {
return predicate.Account(sql.FieldNotNull(FieldLoadFactor))
}
// PriorityEQ applies the EQ predicate on the "priority" field.
func PriorityEQ(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPriority, v))

View File

@@ -139,6 +139,20 @@ func (_c *AccountCreate) SetNillableConcurrency(v *int) *AccountCreate {
return _c
}
// SetLoadFactor sets the "load_factor" field.
func (_c *AccountCreate) SetLoadFactor(v int) *AccountCreate {
_c.mutation.SetLoadFactor(v)
return _c
}
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
func (_c *AccountCreate) SetNillableLoadFactor(v *int) *AccountCreate {
if v != nil {
_c.SetLoadFactor(*v)
}
return _c
}
// SetPriority sets the "priority" field.
func (_c *AccountCreate) SetPriority(v int) *AccountCreate {
_c.mutation.SetPriority(v)
@@ -623,6 +637,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldConcurrency, field.TypeInt, value)
_node.Concurrency = value
}
if value, ok := _c.mutation.LoadFactor(); ok {
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
_node.LoadFactor = &value
}
if value, ok := _c.mutation.Priority(); ok {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
_node.Priority = value
@@ -936,6 +954,30 @@ func (u *AccountUpsert) AddConcurrency(v int) *AccountUpsert {
return u
}
// SetLoadFactor sets the "load_factor" field.
func (u *AccountUpsert) SetLoadFactor(v int) *AccountUpsert {
u.Set(account.FieldLoadFactor, v)
return u
}
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
func (u *AccountUpsert) UpdateLoadFactor() *AccountUpsert {
u.SetExcluded(account.FieldLoadFactor)
return u
}
// AddLoadFactor adds v to the "load_factor" field.
func (u *AccountUpsert) AddLoadFactor(v int) *AccountUpsert {
u.Add(account.FieldLoadFactor, v)
return u
}
// ClearLoadFactor clears the value of the "load_factor" field.
func (u *AccountUpsert) ClearLoadFactor() *AccountUpsert {
u.SetNull(account.FieldLoadFactor)
return u
}
// SetPriority sets the "priority" field.
func (u *AccountUpsert) SetPriority(v int) *AccountUpsert {
u.Set(account.FieldPriority, v)
@@ -1419,6 +1461,34 @@ func (u *AccountUpsertOne) UpdateConcurrency() *AccountUpsertOne {
})
}
// SetLoadFactor sets the "load_factor" field.
func (u *AccountUpsertOne) SetLoadFactor(v int) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.SetLoadFactor(v)
})
}
// AddLoadFactor adds v to the "load_factor" field.
func (u *AccountUpsertOne) AddLoadFactor(v int) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.AddLoadFactor(v)
})
}
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
func (u *AccountUpsertOne) UpdateLoadFactor() *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.UpdateLoadFactor()
})
}
// ClearLoadFactor clears the value of the "load_factor" field.
func (u *AccountUpsertOne) ClearLoadFactor() *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.ClearLoadFactor()
})
}
// SetPriority sets the "priority" field.
func (u *AccountUpsertOne) SetPriority(v int) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
@@ -2113,6 +2183,34 @@ func (u *AccountUpsertBulk) UpdateConcurrency() *AccountUpsertBulk {
})
}
// SetLoadFactor sets the "load_factor" field.
func (u *AccountUpsertBulk) SetLoadFactor(v int) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.SetLoadFactor(v)
})
}
// AddLoadFactor adds v to the "load_factor" field.
func (u *AccountUpsertBulk) AddLoadFactor(v int) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.AddLoadFactor(v)
})
}
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
func (u *AccountUpsertBulk) UpdateLoadFactor() *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.UpdateLoadFactor()
})
}
// ClearLoadFactor clears the value of the "load_factor" field.
func (u *AccountUpsertBulk) ClearLoadFactor() *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.ClearLoadFactor()
})
}
// SetPriority sets the "priority" field.
func (u *AccountUpsertBulk) SetPriority(v int) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {

View File

@@ -172,6 +172,33 @@ func (_u *AccountUpdate) AddConcurrency(v int) *AccountUpdate {
return _u
}
// SetLoadFactor sets the "load_factor" field.
func (_u *AccountUpdate) SetLoadFactor(v int) *AccountUpdate {
_u.mutation.ResetLoadFactor()
_u.mutation.SetLoadFactor(v)
return _u
}
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
func (_u *AccountUpdate) SetNillableLoadFactor(v *int) *AccountUpdate {
if v != nil {
_u.SetLoadFactor(*v)
}
return _u
}
// AddLoadFactor adds value to the "load_factor" field.
func (_u *AccountUpdate) AddLoadFactor(v int) *AccountUpdate {
_u.mutation.AddLoadFactor(v)
return _u
}
// ClearLoadFactor clears the value of the "load_factor" field.
func (_u *AccountUpdate) ClearLoadFactor() *AccountUpdate {
_u.mutation.ClearLoadFactor()
return _u
}
// SetPriority sets the "priority" field.
func (_u *AccountUpdate) SetPriority(v int) *AccountUpdate {
_u.mutation.ResetPriority()
@@ -684,6 +711,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedConcurrency(); ok {
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
}
if value, ok := _u.mutation.LoadFactor(); ok {
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedLoadFactor(); ok {
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
}
if _u.mutation.LoadFactorCleared() {
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
}
if value, ok := _u.mutation.Priority(); ok {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
}
@@ -1063,6 +1099,33 @@ func (_u *AccountUpdateOne) AddConcurrency(v int) *AccountUpdateOne {
return _u
}
// SetLoadFactor sets the "load_factor" field.
func (_u *AccountUpdateOne) SetLoadFactor(v int) *AccountUpdateOne {
_u.mutation.ResetLoadFactor()
_u.mutation.SetLoadFactor(v)
return _u
}
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
func (_u *AccountUpdateOne) SetNillableLoadFactor(v *int) *AccountUpdateOne {
if v != nil {
_u.SetLoadFactor(*v)
}
return _u
}
// AddLoadFactor adds value to the "load_factor" field.
func (_u *AccountUpdateOne) AddLoadFactor(v int) *AccountUpdateOne {
_u.mutation.AddLoadFactor(v)
return _u
}
// ClearLoadFactor clears the value of the "load_factor" field.
func (_u *AccountUpdateOne) ClearLoadFactor() *AccountUpdateOne {
_u.mutation.ClearLoadFactor()
return _u
}
// SetPriority sets the "priority" field.
func (_u *AccountUpdateOne) SetPriority(v int) *AccountUpdateOne {
_u.mutation.ResetPriority()
@@ -1605,6 +1668,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if value, ok := _u.mutation.AddedConcurrency(); ok {
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
}
if value, ok := _u.mutation.LoadFactor(); ok {
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedLoadFactor(); ok {
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
}
if _u.mutation.LoadFactorCleared() {
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
}
if value, ok := _u.mutation.Priority(); ok {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
}

View File

@@ -106,6 +106,7 @@ var (
{Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "concurrency", Type: field.TypeInt, Default: 3},
{Name: "load_factor", Type: field.TypeInt, Nullable: true},
{Name: "priority", Type: field.TypeInt, Default: 50},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
@@ -132,7 +133,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "accounts_proxies_proxy",
Columns: []*schema.Column{AccountsColumns[27]},
Columns: []*schema.Column{AccountsColumns[28]},
RefColumns: []*schema.Column{ProxiesColumns[0]},
OnDelete: schema.SetNull,
},
@@ -151,52 +152,52 @@ var (
{
Name: "account_status",
Unique: false,
Columns: []*schema.Column{AccountsColumns[13]},
Columns: []*schema.Column{AccountsColumns[14]},
},
{
Name: "account_proxy_id",
Unique: false,
Columns: []*schema.Column{AccountsColumns[27]},
Columns: []*schema.Column{AccountsColumns[28]},
},
{
Name: "account_priority",
Unique: false,
Columns: []*schema.Column{AccountsColumns[11]},
Columns: []*schema.Column{AccountsColumns[12]},
},
{
Name: "account_last_used_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[15]},
Columns: []*schema.Column{AccountsColumns[16]},
},
{
Name: "account_schedulable",
Unique: false,
Columns: []*schema.Column{AccountsColumns[18]},
Columns: []*schema.Column{AccountsColumns[19]},
},
{
Name: "account_rate_limited_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[19]},
Columns: []*schema.Column{AccountsColumns[20]},
},
{
Name: "account_rate_limit_reset_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[20]},
Columns: []*schema.Column{AccountsColumns[21]},
},
{
Name: "account_overload_until",
Unique: false,
Columns: []*schema.Column{AccountsColumns[21]},
Columns: []*schema.Column{AccountsColumns[22]},
},
{
Name: "account_platform_priority",
Unique: false,
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]},
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[12]},
},
{
Name: "account_priority_status",
Unique: false,
Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]},
Columns: []*schema.Column{AccountsColumns[12], AccountsColumns[14]},
},
{
Name: "account_deleted_at",

View File

@@ -2260,6 +2260,8 @@ type AccountMutation struct {
extra *map[string]interface{}
concurrency *int
addconcurrency *int
load_factor *int
addload_factor *int
priority *int
addpriority *int
rate_multiplier *float64
@@ -2845,6 +2847,76 @@ func (m *AccountMutation) ResetConcurrency() {
m.addconcurrency = nil
}
// SetLoadFactor sets the "load_factor" field.
func (m *AccountMutation) SetLoadFactor(i int) {
m.load_factor = &i
m.addload_factor = nil
}
// LoadFactor returns the value of the "load_factor" field in the mutation.
func (m *AccountMutation) LoadFactor() (r int, exists bool) {
v := m.load_factor
if v == nil {
return
}
return *v, true
}
// OldLoadFactor returns the old "load_factor" field's value of the Account entity.
// If the Account object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AccountMutation) OldLoadFactor(ctx context.Context) (v *int, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldLoadFactor is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldLoadFactor requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldLoadFactor: %w", err)
}
return oldValue.LoadFactor, nil
}
// AddLoadFactor adds i to the "load_factor" field.
func (m *AccountMutation) AddLoadFactor(i int) {
if m.addload_factor != nil {
*m.addload_factor += i
} else {
m.addload_factor = &i
}
}
// AddedLoadFactor returns the value that was added to the "load_factor" field in this mutation.
func (m *AccountMutation) AddedLoadFactor() (r int, exists bool) {
v := m.addload_factor
if v == nil {
return
}
return *v, true
}
// ClearLoadFactor clears the value of the "load_factor" field.
func (m *AccountMutation) ClearLoadFactor() {
m.load_factor = nil
m.addload_factor = nil
m.clearedFields[account.FieldLoadFactor] = struct{}{}
}
// LoadFactorCleared returns if the "load_factor" field was cleared in this mutation.
func (m *AccountMutation) LoadFactorCleared() bool {
_, ok := m.clearedFields[account.FieldLoadFactor]
return ok
}
// ResetLoadFactor resets all changes to the "load_factor" field.
func (m *AccountMutation) ResetLoadFactor() {
m.load_factor = nil
m.addload_factor = nil
delete(m.clearedFields, account.FieldLoadFactor)
}
// SetPriority sets the "priority" field.
func (m *AccountMutation) SetPriority(i int) {
m.priority = &i
@@ -3773,7 +3845,7 @@ func (m *AccountMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *AccountMutation) Fields() []string {
fields := make([]string, 0, 27)
fields := make([]string, 0, 28)
if m.created_at != nil {
fields = append(fields, account.FieldCreatedAt)
}
@@ -3807,6 +3879,9 @@ func (m *AccountMutation) Fields() []string {
if m.concurrency != nil {
fields = append(fields, account.FieldConcurrency)
}
if m.load_factor != nil {
fields = append(fields, account.FieldLoadFactor)
}
if m.priority != nil {
fields = append(fields, account.FieldPriority)
}
@@ -3885,6 +3960,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
return m.ProxyID()
case account.FieldConcurrency:
return m.Concurrency()
case account.FieldLoadFactor:
return m.LoadFactor()
case account.FieldPriority:
return m.Priority()
case account.FieldRateMultiplier:
@@ -3948,6 +4025,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldProxyID(ctx)
case account.FieldConcurrency:
return m.OldConcurrency(ctx)
case account.FieldLoadFactor:
return m.OldLoadFactor(ctx)
case account.FieldPriority:
return m.OldPriority(ctx)
case account.FieldRateMultiplier:
@@ -4066,6 +4145,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
}
m.SetConcurrency(v)
return nil
case account.FieldLoadFactor:
v, ok := value.(int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetLoadFactor(v)
return nil
case account.FieldPriority:
v, ok := value.(int)
if !ok {
@@ -4189,6 +4275,9 @@ func (m *AccountMutation) AddedFields() []string {
if m.addconcurrency != nil {
fields = append(fields, account.FieldConcurrency)
}
if m.addload_factor != nil {
fields = append(fields, account.FieldLoadFactor)
}
if m.addpriority != nil {
fields = append(fields, account.FieldPriority)
}
@@ -4205,6 +4294,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
switch name {
case account.FieldConcurrency:
return m.AddedConcurrency()
case account.FieldLoadFactor:
return m.AddedLoadFactor()
case account.FieldPriority:
return m.AddedPriority()
case account.FieldRateMultiplier:
@@ -4225,6 +4316,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
}
m.AddConcurrency(v)
return nil
case account.FieldLoadFactor:
v, ok := value.(int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddLoadFactor(v)
return nil
case account.FieldPriority:
v, ok := value.(int)
if !ok {
@@ -4256,6 +4354,9 @@ func (m *AccountMutation) ClearedFields() []string {
if m.FieldCleared(account.FieldProxyID) {
fields = append(fields, account.FieldProxyID)
}
if m.FieldCleared(account.FieldLoadFactor) {
fields = append(fields, account.FieldLoadFactor)
}
if m.FieldCleared(account.FieldErrorMessage) {
fields = append(fields, account.FieldErrorMessage)
}
@@ -4312,6 +4413,9 @@ func (m *AccountMutation) ClearField(name string) error {
case account.FieldProxyID:
m.ClearProxyID()
return nil
case account.FieldLoadFactor:
m.ClearLoadFactor()
return nil
case account.FieldErrorMessage:
m.ClearErrorMessage()
return nil
@@ -4386,6 +4490,9 @@ func (m *AccountMutation) ResetField(name string) error {
case account.FieldConcurrency:
m.ResetConcurrency()
return nil
case account.FieldLoadFactor:
m.ResetLoadFactor()
return nil
case account.FieldPriority:
m.ResetPriority()
return nil
@@ -10191,7 +10298,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 30)
fields := make([]string, 0, 31)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}

View File

@@ -212,29 +212,29 @@ func init() {
// account.DefaultConcurrency holds the default value on creation for the concurrency field.
account.DefaultConcurrency = accountDescConcurrency.Default.(int)
// accountDescPriority is the schema descriptor for priority field.
accountDescPriority := accountFields[8].Descriptor()
accountDescPriority := accountFields[9].Descriptor()
// account.DefaultPriority holds the default value on creation for the priority field.
account.DefaultPriority = accountDescPriority.Default.(int)
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
accountDescRateMultiplier := accountFields[9].Descriptor()
accountDescRateMultiplier := accountFields[10].Descriptor()
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
// accountDescStatus is the schema descriptor for status field.
accountDescStatus := accountFields[10].Descriptor()
accountDescStatus := accountFields[11].Descriptor()
// account.DefaultStatus holds the default value on creation for the status field.
account.DefaultStatus = accountDescStatus.Default.(string)
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
accountDescAutoPauseOnExpired := accountFields[15].Descriptor()
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
// accountDescSchedulable is the schema descriptor for schedulable field.
accountDescSchedulable := accountFields[15].Descriptor()
accountDescSchedulable := accountFields[16].Descriptor()
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
accountDescSessionWindowStatus := accountFields[23].Descriptor()
accountDescSessionWindowStatus := accountFields[24].Descriptor()
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
accountgroupFields := schema.AccountGroup{}.Fields()

View File

@@ -97,6 +97,8 @@ func (Account) Fields() []ent.Field {
field.Int("concurrency").
Default(3),
field.Int("load_factor").Optional().Nillable(),
// priority: 账户优先级,数值越小优先级越高
// 调度器会优先使用高优先级的账户
field.Int("priority").

View File

@@ -124,6 +124,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
@@ -171,8 +173,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -182,7 +182,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -203,6 +202,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -285,6 +286,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -398,8 +403,6 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
@@ -455,8 +458,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=

View File

@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
type GatewayOpenAIWSConfig struct {
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false关闭时保持 legacy 行为)
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
// IngressModeDefault: ingress 默认模式off/shared/dedicated
// IngressModeDefault: ingress 默认模式off/ctx_pool/passthrough
IngressModeDefault string `mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true
Enabled bool `mapstructure:"enabled"`
@@ -1227,7 +1227,7 @@ func setDefaults() {
// Ops (vNext)
viper.SetDefault("ops.enabled", true)
viper.SetDefault("ops.use_preaggregated_tables", false)
viper.SetDefault("ops.use_preaggregated_tables", true)
viper.SetDefault("ops.cleanup.enabled", true)
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
// Retention days: vNext defaults to 30 days across ops datasets.
@@ -1335,7 +1335,7 @@ func setDefaults() {
// OpenAI Responses WebSocket默认开启可通过 force_http 紧急回滚)
viper.SetDefault("gateway.openai_ws.enabled", true)
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
viper.SetDefault("gateway.openai_ws.force_http", false)
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
switch mode {
case "off", "shared", "dedicated":
case "off", "ctx_pool", "passthrough":
case "shared", "dedicated":
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
default:
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
}
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {

View File

@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
}
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
}
}
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
},
{
name: "ingress_mode_default 必须为 off|shared|dedicated",
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
wantErr: "gateway.openai_ws.ingress_mode_default",
},

View File

@@ -102,6 +102,7 @@ type CreateAccountRequest struct {
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
@@ -120,6 +121,7 @@ type UpdateAccountRequest struct {
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
@@ -135,6 +137,7 @@ type BulkUpdateAccountsRequest struct {
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
@@ -217,6 +220,7 @@ func (h *AccountHandler) List(c *gin.Context) {
if len(search) > 100 {
search = search[:100]
}
lite := parseBoolQueryWithDefault(c.Query("lite"), false)
var groupID int64
if groupIDStr := c.Query("group"); groupIDStr != "" {
@@ -235,10 +239,16 @@ func (h *AccountHandler) List(c *gin.Context) {
accountIDs[i] = acc.ID
}
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
if err != nil {
// Log error but don't fail the request, just use 0 for all
concurrencyCounts = make(map[int64]int)
concurrencyCounts := make(map[int64]int)
var windowCosts map[int64]float64
var activeSessions map[int64]int
var rpmCounts map[int64]int
// 始终获取并发数Redis ZCARD极低开销
if h.concurrencyService != nil {
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
concurrencyCounts = cc
}
}
// 识别需要查询窗口费用、会话数和 RPM 的账号Anthropic OAuth/SetupToken 且启用了相应功能)
@@ -262,12 +272,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 并行获取窗口费用、活跃会话数和 RPM 计数
var windowCosts map[int64]float64
var activeSessions map[int64]int
var rpmCounts map[int64]int
// 获取 RPM 计数(批量查询)
// 始终获取 RPM 计数Redis GET极低开销
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
if rpmCounts == nil {
@@ -275,7 +280,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置
// 始终获取活跃会话数(Redis ZCARD低开销
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
@@ -283,32 +288,48 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 获取窗口费用(并行查询)
// 窗口费用获取lite 模式从快照缓存读取,非 lite 模式执行 PostgreSQL 查询后写入缓存
if len(windowCostAccountIDs) > 0 {
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
if lite {
// lite 模式:尝试从快照缓存读取
cacheKey := buildWindowCostCacheKey(windowCostAccountIDs)
if cached, ok := accountWindowCostCache.Get(cacheKey); ok {
if costs, ok := cached.Payload.(map[int64]float64); ok {
windowCosts = costs
}
return nil // 不返回错误,允许部分失败
})
}
// 缓存未命中则 windowCosts 保持 nil仅发生在服务刚启动时
} else {
// 非 lite 模式:执行 PostgreSQL 聚合查询(高开销)
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
}
return nil // 不返回错误,允许部分失败
})
}
_ = g.Wait()
// 查询完毕后写入快照缓存,供 lite 模式使用
cacheKey := buildWindowCostCacheKey(windowCostAccountIDs)
accountWindowCostCache.Set(cacheKey, windowCosts)
}
_ = g.Wait()
}
// Build response with concurrency info
@@ -344,7 +365,7 @@ func (h *AccountHandler) List(c *gin.Context) {
result[i] = item
}
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite)
if etag != "" {
c.Header("ETag", etag)
c.Header("Vary", "If-None-Match")
@@ -362,6 +383,7 @@ func buildAccountsListETag(
total int64,
page, pageSize int,
platform, accountType, status, search string,
lite bool,
) string {
payload := struct {
Total int64 `json:"total"`
@@ -371,6 +393,7 @@ func buildAccountsListETag(
AccountType string `json:"type"`
Status string `json:"status"`
Search string `json:"search"`
Lite bool `json:"lite"`
Items []AccountWithConcurrency `json:"items"`
}{
Total: total,
@@ -380,6 +403,7 @@ func buildAccountsListETag(
AccountType: accountType,
Status: status,
Search: search,
Lite: lite,
Items: items,
}
raw, err := json.Marshal(payload)
@@ -501,6 +525,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
LoadFactor: req.LoadFactor,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
@@ -570,6 +595,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
RateMultiplier: req.RateMultiplier,
LoadFactor: req.LoadFactor,
Status: req.Status,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
@@ -1096,6 +1122,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.Concurrency != nil ||
req.Priority != nil ||
req.RateMultiplier != nil ||
req.LoadFactor != nil ||
req.Status != "" ||
req.Schedulable != nil ||
req.GroupIDs != nil ||
@@ -1114,6 +1141,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
LoadFactor: req.LoadFactor,
Status: req.Status,
Schedulable: req.Schedulable,
GroupIDs: req.GroupIDs,
@@ -1323,6 +1351,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// ResetQuota handles resetting account quota usage
// POST /api/v1/admin/accounts/:id/reset-quota
func (h *AccountHandler) ResetQuota(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
if err := h.adminService.ResetAccountQuota(c.Request.Context(), accountID); err != nil {
response.InternalError(c, "Failed to reset account quota: "+err.Error())
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// GetTempUnschedulable handles getting temporary unschedulable status
// GET /api/v1/admin/accounts/:id/temp-unschedulable
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
@@ -1398,18 +1449,41 @@ func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
return
}
if len(req.AccountIDs) == 0 {
accountIDs := normalizeInt64IDList(req.AccountIDs)
if len(accountIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs)
if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"stats": stats})
payload := gin.H{"stats": stats}
cached := accountTodayStatsBatchCache.Set(cacheKey, payload)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// SetSchedulableRequest represents the request body for setting schedulable status

View File

@@ -0,0 +1,25 @@
package admin
import (
"strconv"
"strings"
"time"
)
var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second)
func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string {
if len(accountIDs) == 0 {
return "accounts_today_stats_empty"
}
var b strings.Builder
b.Grow(len(accountIDs) * 6)
_, _ = b.WriteString("accounts_today_stats:")
for i, id := range accountIDs {
if i > 0 {
_ = b.WriteByte(',')
}
_, _ = b.WriteString(strconv.FormatInt(id, 10))
}
return b.String()
}

View File

@@ -0,0 +1,25 @@
package admin
import (
"strconv"
"strings"
"time"
)
var accountWindowCostCache = newSnapshotCache(30 * time.Second)
func buildWindowCostCacheKey(accountIDs []int64) string {
if len(accountIDs) == 0 {
return "accounts_window_cost_empty"
}
var b strings.Builder
b.Grow(len(accountIDs) * 6)
_, _ = b.WriteString("accounts_window_cost:")
for i, id := range accountIDs {
if i > 0 {
_ = b.WriteByte(',')
}
_, _ = b.WriteString(strconv.FormatInt(id, 10))
}
return b.String()
}

View File

@@ -425,5 +425,9 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return nil, service.ErrAPIKeyNotFound
}
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
return nil
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -1,6 +1,7 @@
package admin
import (
"encoding/json"
"errors"
"strconv"
"strings"
@@ -460,6 +461,9 @@ type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
@@ -469,18 +473,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
if len(req.UserIDs) == 0 {
userIDs := normalizeInt64IDList(req.UserIDs)
if len(userIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
keyRaw, _ := json.Marshal(struct {
UserIDs []int64 `json:"user_ids"`
}{
UserIDs: userIDs,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
payload := gin.H{"stats": stats}
dashboardBatchUsersUsageCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
@@ -497,16 +517,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
if len(req.APIKeyIDs) == 0 {
apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs)
if len(apiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
keyRaw, _ := json.Marshal(struct {
APIKeyIDs []int64 `json:"api_key_ids"`
}{
APIKeyIDs: apiKeyIDs,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
payload := gin.H{"stats": stats}
dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}

View File

@@ -0,0 +1,292 @@
package admin
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type dashboardSnapshotV2Stats struct {
usagestats.DashboardStats
Uptime int64 `json:"uptime"`
}
type dashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
Granularity string `json:"granularity"`
Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"`
Trend []usagestats.TrendDataPoint `json:"trend,omitempty"`
Models []usagestats.ModelStat `json:"models,omitempty"`
Groups []usagestats.GroupStat `json:"groups,omitempty"`
UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"`
}
type dashboardSnapshotV2Filters struct {
UserID int64
APIKeyID int64
AccountID int64
GroupID int64
Model string
RequestType *int16
Stream *bool
BillingType *int8
}
type dashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Model string `json:"model"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
IncludeStats bool `json:"include_stats"`
IncludeTrend bool `json:"include_trend"`
IncludeModels bool `json:"include_models"`
IncludeGroups bool `json:"include_groups"`
IncludeUsersTrend bool `json:"include_users_trend"`
UsersTrendLimit int `json:"users_trend_limit"`
}
func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day"))
if granularity != "hour" {
granularity = "day"
}
includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true)
includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true)
includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true)
includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false)
includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false)
usersTrendLimit := 12
if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" {
if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 {
usersTrendLimit = parsed
}
}
filters, err := parseDashboardSnapshotV2Filters(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
UserID: filters.UserID,
APIKeyID: filters.APIKeyID,
AccountID: filters.AccountID,
GroupID: filters.GroupID,
Model: filters.Model,
RequestType: filters.RequestType,
Stream: filters.Stream,
BillingType: filters.BillingType,
IncludeStats: includeStats,
IncludeTrend: includeTrend,
IncludeModels: includeModels,
IncludeGroups: includeGroups,
IncludeUsersTrend: includeUsersTrend,
UsersTrendLimit: usersTrendLimit,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
resp := &dashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
StartDate: startTime.Format("2006-01-02"),
EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"),
Granularity: granularity,
}
if includeStats {
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
}
resp.Stats = &dashboardSnapshotV2Stats{
DashboardStats: *stats,
Uptime: int64(time.Since(h.startTime).Seconds()),
}
}
if includeTrend {
trend, err := h.dashboardService.GetUsageTrendWithFilters(
c.Request.Context(),
startTime,
endTime,
granularity,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.Model,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
}
resp.Trend = trend
}
if includeModels {
models, err := h.dashboardService.GetModelStatsWithFilters(
c.Request.Context(),
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
}
resp.Models = models
}
if includeGroups {
groups, err := h.dashboardService.GetGroupStatsWithFilters(
c.Request.Context(),
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
}
resp.Groups = groups
}
if includeUsersTrend {
usersTrend, err := h.dashboardService.GetUserUsageTrend(
c.Request.Context(),
startTime,
endTime,
granularity,
usersTrendLimit,
)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
}
resp.UsersTrend = usersTrend
}
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
}
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
filters := &dashboardSnapshotV2Filters{
Model: strings.TrimSpace(c.Query("model")),
}
if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.UserID = id
}
if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.APIKeyID = id
}
if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" {
id, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.AccountID = id
}
if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" {
id, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.GroupID = id
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
return nil, err
}
value := int16(parsed)
filters.RequestType = &value
} else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" {
streamVal, err := strconv.ParseBool(streamStr)
if err != nil {
return nil, err
}
filters.Stream = &streamVal
}
if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" {
v, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
return nil, err
}
bt := int8(v)
filters.BillingType = &bt
}
return filters, nil
}

View File

@@ -0,0 +1,25 @@
package admin
import "sort"
func normalizeInt64IDList(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
out := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
return out
}

View File

@@ -0,0 +1,57 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNormalizeInt64IDList(t *testing.T) {
tests := []struct {
name string
in []int64
want []int64
}{
{"nil input", nil, nil},
{"empty input", []int64{}, nil},
{"single element", []int64{5}, []int64{5}},
{"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}},
{"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}},
{"zero filtered", []int64{0, 1, 2}, []int64{1, 2}},
{"negative filtered", []int64{-5, -1, 3}, []int64{3}},
{"all invalid", []int64{0, -1, -2}, []int64{}},
{"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := normalizeInt64IDList(tc.in)
if tc.want == nil {
require.Nil(t, got)
} else {
require.Equal(t, tc.want, got)
}
})
}
}
func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) {
tests := []struct {
name string
ids []int64
want string
}{
{"empty", nil, "accounts_today_stats_empty"},
{"single", []int64{42}, "accounts_today_stats:42"},
{"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildAccountTodayStatsBatchCacheKey(tc.ids)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -0,0 +1,145 @@
package admin
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type opsDashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
Overview *service.OpsDashboardOverview `json:"overview"`
ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"`
ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"`
}
type opsDashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id"`
QueryMode service.OpsQueryMode `json:"mode"`
BucketSecond int `json:"bucket_second"`
}
// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request.
// GET /api/v1/admin/ops/dashboard/snapshot-v2
func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Platform: filter.Platform,
GroupID: filter.GroupID,
QueryMode: filter.QueryMode,
BucketSecond: bucketSeconds,
})
cacheKey := string(keyRaw)
if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
var (
overview *service.OpsDashboardOverview
trend *service.OpsThroughputTrendResponse
errTrend *service.OpsErrorTrendResponse
)
g, gctx := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
f := *filter
result, err := h.opsService.GetDashboardOverview(gctx, &f)
if err != nil {
return err
}
overview = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
trend = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
errTrend = result
return nil
})
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
resp := &opsDashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
Overview: overview,
ThroughputTrend: trend,
ErrorTrend: errTrend,
}
cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
}

View File

@@ -0,0 +1,155 @@
package admin
import (
"net/http"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ScheduledTestHandler handles admin scheduled-test-plan management.
type ScheduledTestHandler struct {
scheduledTestSvc *service.ScheduledTestService
}
// NewScheduledTestHandler creates a new ScheduledTestHandler.
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
}
type createScheduledTestPlanRequest struct {
AccountID int64 `json:"account_id" binding:"required"`
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression" binding:"required"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
}
type updateScheduledTestPlanRequest struct {
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
}
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid account id")
return
}
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, plans)
}
// Create POST /admin/scheduled-test-plans
func (h *ScheduledTestHandler) Create(c *gin.Context) {
var req createScheduledTestPlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
plan := &service.ScheduledTestPlan{
AccountID: req.AccountID,
ModelID: req.ModelID,
CronExpression: req.CronExpression,
Enabled: true,
MaxResults: req.MaxResults,
}
if req.Enabled != nil {
plan.Enabled = *req.Enabled
}
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
if err != nil {
response.BadRequest(c, err.Error())
return
}
c.JSON(http.StatusOK, created)
}
// Update PUT /admin/scheduled-test-plans/:id
func (h *ScheduledTestHandler) Update(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
if err != nil {
response.NotFound(c, "plan not found")
return
}
var req updateScheduledTestPlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if req.ModelID != "" {
existing.ModelID = req.ModelID
}
if req.CronExpression != "" {
existing.CronExpression = req.CronExpression
}
if req.Enabled != nil {
existing.Enabled = *req.Enabled
}
if req.MaxResults > 0 {
existing.MaxResults = req.MaxResults
}
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
if err != nil {
response.BadRequest(c, err.Error())
return
}
c.JSON(http.StatusOK, updated)
}
// Delete DELETE /admin/scheduled-test-plans/:id
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
}
// ListResults GET /admin/scheduled-test-plans/:id/results
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
limit := 50
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
limit = l
}
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
if err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, results)
}

View File

@@ -77,6 +77,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
@@ -130,12 +131,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -426,50 +428,51 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -520,6 +523,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
@@ -598,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
@@ -747,6 +754,18 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized
}
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
if len(a) != len(b) {
return false
@@ -800,7 +819,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
err := h.emailService.TestSMTPConnectionWithConfig(config)
if err != nil {
response.ErrorFrom(c, err)
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
return
}
@@ -886,7 +905,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
`
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
response.ErrorFrom(c, err)
response.BadRequest(c, "Failed to send test email: "+err.Error())
return
}

View File

@@ -0,0 +1,95 @@
package admin
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"strings"
"sync"
"time"
)
type snapshotCacheEntry struct {
ETag string
Payload any
ExpiresAt time.Time
}
type snapshotCache struct {
mu sync.RWMutex
ttl time.Duration
items map[string]snapshotCacheEntry
}
func newSnapshotCache(ttl time.Duration) *snapshotCache {
if ttl <= 0 {
ttl = 30 * time.Second
}
return &snapshotCache{
ttl: ttl,
items: make(map[string]snapshotCacheEntry),
}
}
func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) {
if c == nil || key == "" {
return snapshotCacheEntry{}, false
}
now := time.Now()
c.mu.RLock()
entry, ok := c.items[key]
c.mu.RUnlock()
if !ok {
return snapshotCacheEntry{}, false
}
if now.After(entry.ExpiresAt) {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
return snapshotCacheEntry{}, false
}
return entry, true
}
func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
if c == nil {
return snapshotCacheEntry{}
}
entry := snapshotCacheEntry{
ETag: buildETagFromAny(payload),
Payload: payload,
ExpiresAt: time.Now().Add(c.ttl),
}
if key == "" {
return entry
}
c.mu.Lock()
c.items[key] = entry
c.mu.Unlock()
return entry
}
func buildETagFromAny(payload any) string {
raw, err := json.Marshal(payload)
if err != nil {
return ""
}
sum := sha256.Sum256(raw)
return "\"" + hex.EncodeToString(sum[:]) + "\""
}
func parseBoolQueryWithDefault(raw string, def bool) bool {
value := strings.TrimSpace(strings.ToLower(raw))
if value == "" {
return def
}
switch value {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
default:
return def
}
}

View File

@@ -0,0 +1,128 @@
//go:build unit
package admin
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSnapshotCache_SetAndGet(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("key1", map[string]string{"hello": "world"})
require.NotEmpty(t, entry.ETag)
require.NotNil(t, entry.Payload)
got, ok := c.Get("key1")
require.True(t, ok)
require.Equal(t, entry.ETag, got.ETag)
}
func TestSnapshotCache_Expiration(t *testing.T) {
c := newSnapshotCache(1 * time.Millisecond)
c.Set("key1", "value")
time.Sleep(5 * time.Millisecond)
_, ok := c.Get("key1")
require.False(t, ok, "expired entry should not be returned")
}
func TestSnapshotCache_GetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_GetMiss(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("nonexistent")
require.False(t, ok)
}
func TestSnapshotCache_NilReceiver(t *testing.T) {
var c *snapshotCache
_, ok := c.Get("key")
require.False(t, ok)
entry := c.Set("key", "value")
require.Empty(t, entry.ETag)
}
func TestSnapshotCache_SetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
// Set with empty key should return entry but not store it
entry := c.Set("", "value")
require.NotEmpty(t, entry.ETag)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_DefaultTTL(t *testing.T) {
c := newSnapshotCache(0)
require.Equal(t, 30*time.Second, c.ttl)
c2 := newSnapshotCache(-1 * time.Second)
require.Equal(t, 30*time.Second, c2.ttl)
}
func TestSnapshotCache_ETagDeterministic(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
payload := map[string]int{"a": 1, "b": 2}
entry1 := c.Set("k1", payload)
entry2 := c.Set("k2", payload)
require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag")
}
func TestSnapshotCache_ETagFormat(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("k", "test")
// ETag should be quoted hex string: "abcdef..."
require.True(t, len(entry.ETag) > 2)
require.Equal(t, byte('"'), entry.ETag[0])
require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1])
}
func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
// channels are not JSON-serializable
etag := buildETagFromAny(make(chan int))
require.Empty(t, etag)
}
func TestParseBoolQueryWithDefault(t *testing.T) {
tests := []struct {
name string
raw string
def bool
want bool
}{
{"empty returns default true", "", true, true},
{"empty returns default false", "", false, false},
{"1", "1", false, true},
{"true", "true", false, true},
{"TRUE", "TRUE", false, true},
{"yes", "yes", false, true},
{"on", "on", false, true},
{"0", "0", true, false},
{"false", "false", true, false},
{"FALSE", "FALSE", true, false},
{"no", "no", true, false},
{"off", "off", true, false},
{"whitespace trimmed", " true ", false, true},
{"unknown returns default true", "maybe", true, true},
{"unknown returns default false", "maybe", false, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := parseBoolQueryWithDefault(tc.raw, tc.def)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -61,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct {
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
exactTotal := false
if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" {
parsed, err := strconv.ParseBool(exactTotalRaw)
if err != nil {
response.BadRequest(c, "Invalid exact_total value, use true or false")
return
}
exactTotal = parsed
}
// Parse filters
var userID, apiKeyID, accountID, groupID int64
@@ -167,6 +176,7 @@ func (h *UsageHandler) List(c *gin.Context) {
BillingType: billingType,
StartTime: startTime,
EndTime: endTime,
ExactTotal: exactTotal,
}
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)

View File

@@ -80,6 +80,29 @@ func TestAdminUsageListInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageListExactTotalTrue(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.True(t, repo.listFilters.ExactTotal)
}
func TestAdminUsageListInvalidExactTotal(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)

View File

@@ -1,7 +1,9 @@
package admin
import (
"encoding/json"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct {
Attributes map[int64]map[int64]string `json:"attributes"`
}
var userAttributesBatchCache = newSnapshotCache(30 * time.Second)
// AttributeDefinitionResponse represents attribute definition response
type AttributeDefinitionResponse struct {
ID int64 `json:"id"`
@@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
return
}
if len(req.UserIDs) == 0 {
userIDs := normalizeInt64IDList(req.UserIDs)
if len(userIDs) == 0 {
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
return
}
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
keyRaw, _ := json.Marshal(struct {
UserIDs []int64 `json:"user_ids"`
}{
UserIDs: userIDs,
})
cacheKey := string(keyRaw)
if cached, ok := userAttributesBatchCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
payload := BatchUserAttributesResponse{Attributes: attrs}
userAttributesBatchCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}

View File

@@ -91,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) {
Search: search,
Attributes: parseAttributeFilters(c),
}
if raw, ok := c.GetQuery("include_subscriptions"); ok {
includeSubscriptions := parseBoolQueryWithDefault(raw, true)
filters.IncludeSubscriptions = &includeSubscriptions
}
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
if err != nil {

View File

@@ -4,6 +4,7 @@ package handler
import (
"context"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -73,7 +74,23 @@ func (h *APIKeyHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
// Parse filter parameters
var filters service.APIKeyListFilters
if search := strings.TrimSpace(c.Query("search")); search != "" {
if len(search) > 100 {
search = search[:100]
}
filters.Search = search
}
filters.Status = c.Query("status")
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
gid, err := strconv.ParseInt(groupIDStr, 10, 64)
if err == nil {
filters.GroupID = &gid
}
}
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params, filters)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -183,6 +183,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
LoadFactor: a.LoadFactor,
Priority: a.Priority,
RateMultiplier: a.BillingRateMultiplier(),
Status: a.Status,
@@ -248,6 +249,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
}
}
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
if a.Type == service.AccountTypeAPIKey {
if limit := a.GetQuotaLimit(); limit > 0 {
out.QuotaLimit = &limit
}
used := a.GetQuotaUsed()
if out.QuotaLimit != nil {
out.QuotaUsed = &used
}
}
return out
}

View File

@@ -17,13 +17,14 @@ type CustomMenuItem struct {
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -88,28 +89,29 @@ type DefaultSubscriptionSetting struct {
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"`
Version string `json:"version"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"`
Version string `json:"version"`
}
// SoraS3Settings Sora S3 存储配置 DTO响应用不含敏感字段

View File

@@ -131,6 +131,7 @@ type Account struct {
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
LoadFactor *int `json:"load_factor,omitempty"`
Priority int `json:"priority"`
RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
@@ -185,6 +186,10 @@ type Account struct {
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
// API Key 账号配额限制
QuotaLimit *float64 `json:"quota_limit,omitempty"`
QuotaUsed *float64 `json:"quota_used,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`

View File

@@ -27,6 +27,7 @@ type AdminHandlers struct {
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
APIKey *admin.AdminAPIKeyHandler
ScheduledTest *admin.ScheduledTestHandler
}
// Handlers contains all HTTP handlers

View File

@@ -0,0 +1,192 @@
package handler
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
var handlerStructuredLogCaptureMu sync.Mutex
type handlerInMemoryLogSink struct {
mu sync.Mutex
events []*logger.LogEvent
}
func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
if event == nil {
return
}
cloned := *event
if event.Fields != nil {
cloned.Fields = make(map[string]any, len(event.Fields))
for k, v := range event.Fields {
cloned.Fields[k] = v
}
}
s.mu.Lock()
s.events = append(s.events, &cloned)
s.mu.Unlock()
}
func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
s.mu.Lock()
defer s.mu.Unlock()
wantLevel := strings.ToLower(strings.TrimSpace(level))
for _, ev := range s.events {
if ev == nil {
continue
}
if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel {
return true
}
}
return false
}
func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool {
s.mu.Lock()
defer s.mu.Unlock()
for _, ev := range s.events {
if ev == nil || ev.Fields == nil {
continue
}
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
return true
}
}
return false
}
func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) {
t.Helper()
handlerStructuredLogCaptureMu.Lock()
err := logger.Init(logger.InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: logger.SamplingOptions{Enabled: false},
})
require.NoError(t, err)
sink := &handlerInMemoryLogSink{}
logger.SetSink(sink)
return sink, func() {
logger.SetSink(nil)
handlerStructuredLogCaptureMu.Unlock()
}
}
func TestIsOpenAIRemoteCompactPath(t *testing.T) {
require.False(t, isOpenAIRemoteCompactPath(nil))
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
require.True(t, isOpenAIRemoteCompactPath(c))
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil)
require.True(t, isOpenAIRemoteCompactPath(c))
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
require.False(t, isOpenAIRemoteCompactPath(c))
}
func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Set(opsModelKey, "gpt-5.3-codex")
c.Set(opsAccountIDKey, int64(123))
c.Header("x-request-id", "rid-compact-ok")
c.Status(http.StatusOK)
h := &OpenAIGatewayHandler{}
h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond))
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded"))
require.True(t, logSink.ContainsFieldValue("status_code", "200"))
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex"))
require.True(t, logSink.ContainsFieldValue("account_id", "123"))
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok"))
}
func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Status(http.StatusBadGateway)
h := &OpenAIGatewayHandler{}
h.logOpenAIRemoteCompactOutcome(c, time.Now())
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed"))
require.True(t, logSink.ContainsFieldValue("status_code", "502"))
require.True(t, logSink.ContainsFieldValue("path", "/responses/compact"))
}
func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
c.Status(http.StatusOK)
h := &OpenAIGatewayHandler{}
h.logOpenAIRemoteCompactOutcome(c, time.Now())
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
}
func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
h := &OpenAIGatewayHandler{}
h.Responses(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
require.True(t, logSink.ContainsFieldValue("status_code", "401"))
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
}

View File

@@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
cfg *config.Config
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -61,6 +62,7 @@ func NewOpenAIGatewayHandler(
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
}
}
@@ -70,6 +72,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
compactStartedAt := time.Now()
defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt)
setOpenAIClientTransportHTTP(c)
requestStart := time.Now()
@@ -340,6 +344,86 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
}
func isOpenAIRemoteCompactPath(c *gin.Context) bool {
if c == nil || c.Request == nil || c.Request.URL == nil {
return false
}
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
return strings.HasSuffix(normalizedPath, "/responses/compact")
}
func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) {
if !isOpenAIRemoteCompactPath(c) {
return
}
var (
ctx = context.Background()
path string
status int
)
if c != nil {
if c.Request != nil {
ctx = c.Request.Context()
if c.Request.URL != nil {
path = strings.TrimSpace(c.Request.URL.Path)
}
}
if c.Writer != nil {
status = c.Writer.Status()
}
}
outcome := "failed"
if status >= 200 && status < 300 {
outcome = "succeeded"
}
latencyMs := time.Since(startedAt).Milliseconds()
if latencyMs < 0 {
latencyMs = 0
}
fields := []zap.Field{
zap.String("component", "handler.openai_gateway.responses"),
zap.Bool("remote_compact", true),
zap.String("compact_outcome", outcome),
zap.Int("status_code", status),
zap.Int64("latency_ms", latencyMs),
zap.String("path", path),
zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI),
}
if c != nil {
if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" {
fields = append(fields, zap.String("request_user_agent", userAgent))
}
if v, ok := c.Get(opsModelKey); ok {
if model, ok := v.(string); ok && strings.TrimSpace(model) != "" {
fields = append(fields, zap.String("request_model", strings.TrimSpace(model)))
}
}
if v, ok := c.Get(opsAccountIDKey); ok {
if accountID, ok := v.(int64); ok && accountID > 0 {
fields = append(fields, zap.Int64("account_id", accountID))
}
}
if c.Writer != nil {
if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" {
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
} else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" {
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
}
}
}
log := logger.FromContext(ctx).With(fields...)
if outcome == "succeeded" {
log.Info("codex.remote_compact.succeeded")
return
}
log.Warn("codex.remote_compact.failed")
}
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
return true

View File

@@ -32,27 +32,28 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled,
Version: h.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled,
Version: h.version,
})
}

View File

@@ -996,7 +996,7 @@ func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*se
}
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
@@ -2132,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service
return 0, nil
}
func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
return nil
}
func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
return nil
}
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)

View File

@@ -216,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
return 0, nil
}
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
func (r *stubAccountRepo) listSchedulable() []service.Account {
var result []service.Account
for _, acc := range r.accounts {

View File

@@ -30,6 +30,7 @@ func ProvideAdminHandlers(
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
apiKeyHandler *admin.AdminAPIKeyHandler,
scheduledTestHandler *admin.ScheduledTestHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -53,6 +54,7 @@ func ProvideAdminHandlers(
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
APIKey: apiKeyHandler,
ScheduledTest: scheduledTestHandler,
}
}
@@ -141,6 +143,7 @@ var ProviderSet = wire.NewSet(
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
admin.NewAdminAPIKeyHandler,
admin.NewScheduledTestHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -15,6 +15,7 @@ type Model struct {
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},

View File

@@ -57,25 +57,28 @@ type DashboardStats struct {
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheCreationTokens int64 `json:"cache_creation_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheCreationTokens int64 `json:"cache_creation_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// GroupStat represents usage statistics for a single group
@@ -154,6 +157,8 @@ type UsageLogFilters struct {
BillingType *int8
StartTime *time.Time
EndTime *time.Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
ExactTotal bool
}
// UsageStats represents usage statistics

View File

@@ -84,6 +84,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.LoadFactor != nil {
builder.SetLoadFactor(*account.LoadFactor)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
@@ -318,6 +321,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.LoadFactor != nil {
builder.SetLoadFactor(*account.LoadFactor)
} else {
builder.ClearLoadFactor()
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
@@ -437,6 +445,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
switch status {
case "rate_limited":
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
case "temp_unschedulable":
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.And(
entsql.Not(entsql.IsNull(col)),
entsql.GT(col, entsql.Expr("NOW()")),
))
}))
default:
q = q.Where(dbaccount.StatusEQ(status))
}
@@ -640,7 +656,17 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
SetStatus(service.StatusActive).
SetErrorMessage("").
Save(ctx)
return err
if err != nil {
return err
}
// 清除临时不可调度状态,重置 401 升级链
_, _ = r.sql.ExecContext(ctx, `
UPDATE accounts
SET temp_unschedulable_until = NULL,
temp_unschedulable_reason = NULL
WHERE id = $1 AND deleted_at IS NULL
`, id)
return nil
}
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
@@ -1205,6 +1231,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.RateMultiplier)
idx++
}
if updates.LoadFactor != nil {
if *updates.LoadFactor <= 0 {
setClauses = append(setClauses, "load_factor = NULL")
} else {
setClauses = append(setClauses, "load_factor = $"+itoa(idx))
args = append(args, *updates.LoadFactor)
idx++
}
}
if updates.Status != nil {
setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status)
@@ -1527,6 +1562,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
LoadFactor: m.LoadFactor,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
@@ -1639,3 +1675,60 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
return r.accountsToService(ctx, accounts)
}
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
rows, err := r.sql.QueryContext(ctx,
`UPDATE accounts SET extra = jsonb_set(
COALESCE(extra, '{}'::jsonb),
'{quota_used}',
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`,
amount, id)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
var newUsed, limit float64
if rows.Next() {
if err := rows.Scan(&newUsed, &limit); err != nil {
return err
}
}
if err := rows.Err(); err != nil {
return err
}
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
}
}
return nil
}
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx,
`UPDATE accounts SET extra = jsonb_set(
COALESCE(extra, '{}'::jsonb),
'{quota_used}',
'0'::jsonb
), updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`,
id)
if err != nil {
return err
}
// 重置配额后触发调度快照刷新,使账号重新参与调度
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota reset failed: account=%d err=%v", id, err)
}
return nil
}

View File

@@ -281,9 +281,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
// Apply filters
if filters.Search != "" {
q = q.Where(apikey.Or(
apikey.NameContainsFold(filters.Search),
apikey.KeyContainsFold(filters.Search),
))
}
if filters.Status != "" {
q = q.Where(apikey.StatusEQ(filters.Status))
}
if filters.GroupID != nil {
if *filters.GroupID == 0 {
q = q.Where(apikey.GroupIDIsNil())
} else {
q = q.Where(apikey.GroupIDEQ(*filters.GroupID))
}
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err

View File

@@ -158,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
@@ -170,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
@@ -314,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body))
return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg)
}
var usageResp service.ClaudeUsageResponse

View File

@@ -0,0 +1,183 @@
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// --- Plan Repository ---
type scheduledTestPlanRepository struct {
db *sql.DB
}
func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository {
return &scheduledTestPlanRepository{db: db}
}
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans WHERE id = $1
`, id)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans WHERE account_id = $1
ORDER BY created_at DESC
`, accountID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
return scanPlans(rows)
}
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans
WHERE enabled = true AND next_run_at <= $1
ORDER BY next_run_at ASC
`, now)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
return scanPlans(rows)
}
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
UPDATE scheduled_test_plans
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
WHERE id = $1
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error {
_, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id)
return err
}
func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error {
_, err := r.db.ExecContext(ctx, `
UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1
`, id, lastRunAt, nextRunAt)
return err
}
// --- Result Repository ---
type scheduledTestResultRepository struct {
db *sql.DB
}
func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository {
return &scheduledTestResultRepository{db: db}
}
func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) {
row := r.db.QueryRowContext(ctx, `
INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
`, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt)
out := &service.ScheduledTestResult{}
if err := row.Scan(
&out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage,
&out.LatencyMs, &out.StartedAt, &out.FinishedAt, &out.CreatedAt,
); err != nil {
return nil, err
}
return out, nil
}
func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
FROM scheduled_test_results
WHERE plan_id = $1
ORDER BY created_at DESC
LIMIT $2
`, planID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var results []*service.ScheduledTestResult
for rows.Next() {
r := &service.ScheduledTestResult{}
if err := rows.Scan(
&r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage,
&r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt,
); err != nil {
return nil, err
}
results = append(results, r)
}
return results, rows.Err()
}
func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error {
_, err := r.db.ExecContext(ctx, `
DELETE FROM scheduled_test_results
WHERE id IN (
SELECT id FROM (
SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn
FROM scheduled_test_results
WHERE plan_id = $1
) ranked
WHERE rn > $2
)
`, planID, keepCount)
return err
}
// --- scan helpers ---
type scannable interface {
Scan(dest ...any) error
}
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
p := &service.ScheduledTestPlan{}
if err := row.Scan(
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, err
}
return p, nil
}
func scanPlans(rows *sql.Rows) ([]*service.ScheduledTestPlan, error) {
var plans []*service.ScheduledTestPlan
for rows.Next() {
p, err := scanPlan(rows)
if err != nil {
return nil, err
}
plans = append(plans, p)
}
return plans, rows.Err()
}

View File

@@ -122,7 +122,7 @@ func (s *SettingRepoSuite) TestSet_EmptyValue() {
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
// 模拟保存站点设置,部分字段有值,部分字段为空
settings := map[string]string{
"site_name": "AICodex2API",
"site_name": "Sub2api",
"site_subtitle": "Subscription to API",
"site_logo": "", // 用户未上传Logo
"api_base_url": "", // 用户未设置API地址
@@ -136,7 +136,7 @@ func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")
s.Require().Equal("AICodex2API", result["site_name"])
s.Require().Equal("Sub2api", result["site_name"])
s.Require().Equal("Subscription to API", result["site_subtitle"])
s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")

View File

@@ -1363,7 +1363,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
@@ -1401,6 +1402,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
@@ -1473,7 +1476,16 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
}
whereClause := buildWhere(conditions)
logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params)
var (
logs []service.UsageLog
page *pagination.PaginationResult
err error
)
if shouldUseFastUsageLogTotal(filters) {
logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params)
} else {
logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params)
}
if err != nil {
return nil, nil, err
}
@@ -1484,17 +1496,45 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
return logs, page, nil
}
func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool {
if filters.ExactTotal {
return false
}
// 强选择过滤下记录集通常较小,保留精确总数。
return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0
}
// UsageStats represents usage statistics
type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats
func normalizePositiveInt64IDs(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
seen := make(map[int64]struct{}, len(ids))
out := make([]int64, 0, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
return out
}
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
result := make(map[int64]*BatchUserUsageStats)
if len(userIDs) == 0 {
normalizedUserIDs := normalizePositiveInt64IDs(userIDs)
if len(normalizedUserIDs) == 0 {
return result, nil
}
@@ -1506,58 +1546,36 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
endTime = time.Now()
}
for _, id := range userIDs {
for _, id := range normalizedUserIDs {
result[id] = &BatchUserUsageStats{UserID: id}
}
query := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
SELECT
user_id,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
FROM usage_logs
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
WHERE user_id = ANY($1)
AND created_at >= LEAST($2, $4)
GROUP BY user_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
today := timezone.Today()
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today)
if err != nil {
return nil, err
}
for rows.Next() {
var userID int64
var total float64
if err := rows.Scan(&userID, &total); err != nil {
var todayTotal float64
if err := rows.Scan(&userID, &total, &todayTotal); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[userID]; ok {
stats.TotalActualCost = total
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost
FROM usage_logs
WHERE user_id = ANY($1) AND created_at >= $2
GROUP BY user_id
`
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today)
if err != nil {
return nil, err
}
for rows.Next() {
var userID int64
var total float64
if err := rows.Scan(&userID, &total); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[userID]; ok {
stats.TodayActualCost = total
stats.TodayActualCost = todayTotal
}
}
if err := rows.Close(); err != nil {
@@ -1577,7 +1595,8 @@ type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 {
normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs)
if len(normalizedAPIKeyIDs) == 0 {
return result, nil
}
@@ -1589,58 +1608,36 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
endTime = time.Now()
}
for _, id := range apiKeyIDs {
for _, id := range normalizedAPIKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
query := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
SELECT
api_key_id,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
FROM usage_logs
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
WHERE api_key_id = ANY($1)
AND created_at >= LEAST($2, $4)
GROUP BY api_key_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
today := timezone.Today()
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today)
if err != nil {
return nil, err
}
for rows.Next() {
var apiKeyID int64
var total float64
if err := rows.Scan(&apiKeyID, &total); err != nil {
var todayTotal float64
if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[apiKeyID]; ok {
stats.TotalActualCost = total
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost
FROM usage_logs
WHERE api_key_id = ANY($1) AND created_at >= $2
GROUP BY api_key_id
`
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today)
if err != nil {
return nil, err
}
for rows.Next() {
var apiKeyID int64
var total float64
if err := rows.Scan(&apiKeyID, &total); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[apiKeyID]; ok {
stats.TodayActualCost = total
stats.TodayActualCost = todayTotal
}
}
if err := rows.Close(); err != nil {
@@ -1670,7 +1667,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
@@ -1753,7 +1751,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
total_requests as requests,
input_tokens,
output_tokens,
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
cache_creation_tokens,
cache_read_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
@@ -1768,7 +1767,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
total_requests as requests,
input_tokens,
output_tokens,
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
cache_creation_tokens,
cache_read_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
@@ -1812,6 +1812,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
@@ -2245,6 +2247,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
return logs, paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
limit := params.Limit()
offset := params.Offset()
limitPos := len(args) + 1
offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), limit+1, offset)
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil {
return nil, nil, err
}
hasMore := false
if len(logs) > limit {
hasMore = true
logs = logs[:limit]
}
total := int64(offset) + int64(len(logs))
if hasMore {
// 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。
total = int64(offset) + int64(limit) + 1
}
return logs, paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
@@ -2599,7 +2630,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.CacheTokens,
&row.CacheCreationTokens,
&row.CacheReadTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
@@ -2623,6 +2655,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.CacheCreationTokens,
&row.CacheReadTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,

View File

@@ -96,6 +96,7 @@ func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
filters := usagestats.UsageLogFilters{
RequestType: &requestType,
Stream: &stream,
ExactTotal: true,
}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
@@ -124,7 +125,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
require.NoError(t, err)
@@ -143,7 +144,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
require.NoError(t, err)

View File

@@ -243,21 +243,24 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
// Batch load active subscriptions with groups to avoid N+1.
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDIn(userIDs...),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
).
WithGroup().
All(ctx)
if err != nil {
return nil, nil, err
}
shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions
if shouldLoadSubscriptions {
// Batch load active subscriptions with groups to avoid N+1.
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDIn(userIDs...),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
).
WithGroup().
All(ctx)
if err != nil {
return nil, nil, err
}
for i := range subs {
if u, ok := userMap[subs[i].UserID]; ok {
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
for i := range subs {
if u, ok := userMap[subs[i].UserID]; ok {
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
}
}
}

View File

@@ -53,7 +53,9 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewScheduledTestPlanRepository, // 定时测试计划仓储
NewScheduledTestResultRepository, // 定时测试结果仓储
NewProxyRepository,
NewRedeemCodeRepository,
NewPromoCodeRepository,

View File

@@ -446,9 +446,10 @@ func TestAPIContracts(t *testing.T) {
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
@@ -487,6 +488,7 @@ func TestAPIContracts(t *testing.T) {
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
@@ -1094,6 +1096,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map
return errors.New("not implemented")
}
func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
s.bulkUpdateIDs = append([]int64{}, ids...)
return int64(len(ids)), nil
@@ -1411,7 +1421,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return nil
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID))
for id := range r.byID {
if r.byID[id].UserID == userID {

View File

@@ -56,7 +56,7 @@ func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {

View File

@@ -537,7 +537,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}

View File

@@ -78,6 +78,9 @@ func RegisterAdminRoutes(
// API Key 管理
registerAdminAPIKeyRoutes(admin, h)
// 定时测试计划
registerScheduledTestRoutes(admin, h)
}
}
@@ -168,6 +171,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
// Dashboard (vNext - raw path for MVP)
ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2)
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
@@ -180,6 +184,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2)
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
@@ -247,6 +252,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
@@ -476,6 +482,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
plans := admin.Group("/scheduled-test-plans")
{
plans.POST("", h.Admin.ScheduledTest.Create)
plans.PUT("/:id", h.Admin.ScheduledTest.Update)
plans.DELETE("/:id", h.Admin.ScheduledTest.Delete)
plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults)
}
// Nested under accounts
admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount)
}
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
rules := admin.Group("/error-passthrough-rules")
{

View File

@@ -28,6 +28,7 @@ type Account struct {
// RateMultiplier 账号计费倍率(>=0允许 0 表示该账号计费为 0
// 使用指针用于兼容旧版本调度缓存Redis中缺字段的情况nil 表示按 1.0 处理。
RateMultiplier *float64
LoadFactor *int // 调度负载因子nil 表示使用 Concurrency
Status string
ErrorMessage string
LastUsedAt *time.Time
@@ -88,6 +89,19 @@ func (a *Account) BillingRateMultiplier() float64 {
return *a.RateMultiplier
}
func (a *Account) EffectiveLoadFactor() int {
if a == nil {
return 1
}
if a.LoadFactor != nil && *a.LoadFactor > 0 {
return *a.LoadFactor
}
if a.Concurrency > 0 {
return a.Concurrency
}
return 1
}
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
@@ -853,15 +867,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
}
const (
OpenAIWSIngressModeOff = "off"
OpenAIWSIngressModeShared = "shared"
OpenAIWSIngressModeDedicated = "dedicated"
OpenAIWSIngressModeOff = "off"
OpenAIWSIngressModeShared = "shared"
OpenAIWSIngressModeDedicated = "dedicated"
OpenAIWSIngressModeCtxPool = "ctx_pool"
OpenAIWSIngressModePassthrough = "passthrough"
)
func normalizeOpenAIWSIngressMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAIWSIngressModeOff:
return OpenAIWSIngressModeOff
case OpenAIWSIngressModeCtxPool:
return OpenAIWSIngressModeCtxPool
case OpenAIWSIngressModePassthrough:
return OpenAIWSIngressModePassthrough
case OpenAIWSIngressModeShared:
return OpenAIWSIngressModeShared
case OpenAIWSIngressModeDedicated:
@@ -873,18 +893,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
return OpenAIWSIngressModeCtxPool
}
return normalized
}
return OpenAIWSIngressModeShared
return OpenAIWSIngressModeCtxPool
}
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式off/shared/dedicated)。
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式off/ctx_pool/passthrough)。
//
// 优先级:
// 1. 分类型 mode 新字段string
// 2. 分类型 enabled 旧字段bool
// 3. 兼容 enabled 旧字段bool
// 4. defaultMode非法时回退 shared
// 4. defaultMode非法时回退 ctx_pool
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
if a == nil || !a.IsOpenAI() {
@@ -919,7 +942,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return "", false
}
if enabled {
return OpenAIWSIngressModeShared, true
return OpenAIWSIngressModeCtxPool, true
}
return OpenAIWSIngressModeOff, true
}
@@ -946,6 +969,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
return mode
}
// 兼容旧值shared/dedicated 语义都归并到 ctx_pool。
if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
return OpenAIWSIngressModeCtxPool
}
return resolvedDefault
}
@@ -1104,6 +1131,38 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
return "5m"
}
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
// 返回 0 表示未启用
func (a *Account) GetQuotaLimit() float64 {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["quota_limit"]; ok {
return parseExtraFloat64(v)
}
return 0
}
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
func (a *Account) GetQuotaUsed() float64 {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["quota_used"]; ok {
return parseExtraFloat64(v)
}
return 0
}
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
func (a *Account) IsQuotaExceeded() bool {
limit := a.GetQuotaLimit()
if limit <= 0 {
return false
}
return a.GetQuotaUsed() >= limit
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {

View File

@@ -0,0 +1,46 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func intPtrHelper(v int) *int { return &v }
func TestEffectiveLoadFactor_NilAccount(t *testing.T) {
var a *Account
require.Equal(t, 1, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) {
a := &Account{Concurrency: 5}
require.Equal(t, 5, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) {
a := &Account{Concurrency: 0}
require.Equal(t, 1, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) {
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)}
require.Equal(t, 20, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) {
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)}
require.Equal(t, 5, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) {
a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)}
require.Equal(t, 3, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) {
a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)}
require.Equal(t, 1, a.EffectiveLoadFactor())
}

View File

@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
}
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
t.Run("default fallback to shared", func(t *testing.T) {
t.Run("default fallback to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
})
t.Run("oauth mode field has highest priority", func(t *testing.T) {
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
"openai_oauth_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": false,
},
}
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("legacy enabled maps to shared", func(t *testing.T) {
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
})
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
shared := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
},
}
dedicated := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
})
t.Run("legacy disabled maps to off", func(t *testing.T) {
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("non openai always off", func(t *testing.T) {

View File

@@ -68,6 +68,10 @@ type AccountRepository interface {
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
ResetQuotaUsed(ctx context.Context, id int64) error
}
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
@@ -78,6 +82,7 @@ type AccountBulkUpdate struct {
Concurrency *int
Priority *int
RateMultiplier *float64
LoadFactor *int
Status *string
Schedulable *bool
Credentials map[string]any

View File

@@ -199,6 +199,14 @@ func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates A
panic("unexpected BulkUpdate call")
}
func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。
// 预期行为:
// - ExistsByID 返回 false账号不存在

View File

@@ -12,6 +12,7 @@ import (
"io"
"log"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
@@ -33,7 +34,7 @@ import (
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
@@ -179,7 +180,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
}
if account.Platform == PlatformAntigravity {
return s.testAntigravityAccountConnection(c, account, modelID)
return s.routeAntigravityTest(c, account, modelID)
}
if account.Platform == PlatformSora {
@@ -238,7 +239,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages?beta=true"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -1176,6 +1177,18 @@ func truncateSoraErrorBody(body []byte, max int) string {
return soraerror.TruncateBody(body, max)
}
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致OAuth/Upstream 走 CRS 中转。
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
if account.Type == AccountTypeAPIKey {
if strings.HasPrefix(modelID, "gemini-") {
return s.testGeminiAccountConnection(c, account, modelID)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
return s.testAntigravityAccountConnection(c, account, modelID)
}
// testAntigravityAccountConnection tests an Antigravity account's connection
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
@@ -1560,3 +1573,62 @@ func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) er
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
return fmt.Errorf("%s", errorMsg)
}
// RunTestBackground executes an account test in-memory (no real HTTP client),
// capturing SSE output via httptest.NewRecorder, then parses the result.
func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID int64, modelID string) (*ScheduledTestResult, error) {
startedAt := time.Now()
w := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
finishedAt := time.Now()
body := w.Body.String()
responseText, errMsg := parseTestSSEOutput(body)
status := "success"
if testErr != nil || errMsg != "" {
status = "failed"
if errMsg == "" && testErr != nil {
errMsg = testErr.Error()
}
}
return &ScheduledTestResult{
Status: status,
ResponseText: responseText,
ErrorMessage: errMsg,
LatencyMs: finishedAt.Sub(startedAt).Milliseconds(),
StartedAt: startedAt,
FinishedAt: finishedAt,
}, nil
}
// parseTestSSEOutput extracts response text and error message from captured SSE output.
func parseTestSSEOutput(body string) (responseText, errMsg string) {
var texts []string
for _, line := range strings.Split(body, "\n") {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "data: ") {
continue
}
jsonStr := strings.TrimPrefix(line, "data: ")
var event TestEvent
if err := json.Unmarshal([]byte(jsonStr), &event); err != nil {
continue
}
switch event.Type {
case "content":
if event.Text != "" {
texts = append(texts, event.Text)
}
case "error":
errMsg = event.Error
}
}
responseText = strings.Join(texts, "")
return
}

View File

@@ -84,6 +84,7 @@ type AdminService interface {
DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
ResetAccountQuota(ctx context.Context, id int64) error
}
// CreateUserInput represents input for creating a new user via admin operations.
@@ -195,6 +196,7 @@ type CreateAccountInput struct {
Concurrency int
Priority int
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
LoadFactor *int
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
@@ -215,6 +217,7 @@ type UpdateAccountInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
LoadFactor *int
Status string
GroupIDs *[]int64
ExpiresAt *int64
@@ -230,6 +233,7 @@ type BulkUpdateAccountsInput struct {
Concurrency *int
Priority *int
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
LoadFactor *int
Status string
Schedulable *bool
GroupIDs *[]int64
@@ -745,7 +749,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
if err != nil {
return nil, 0, err
}
@@ -1413,6 +1417,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
account.RateMultiplier = input.RateMultiplier
}
if input.LoadFactor != nil && *input.LoadFactor > 0 {
if *input.LoadFactor > 10000 {
return nil, errors.New("load_factor must be <= 10000")
}
account.LoadFactor = input.LoadFactor
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
@@ -1458,6 +1468,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Credentials = input.Credentials
}
if len(input.Extra) > 0 {
// 保留 quota_used防止编辑账号时意外重置配额用量
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
input.Extra["quota_used"] = oldQuotaUsed
}
account.Extra = input.Extra
}
if input.ProxyID != nil {
@@ -1483,6 +1497,15 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
}
account.RateMultiplier = input.RateMultiplier
}
if input.LoadFactor != nil {
if *input.LoadFactor <= 0 {
account.LoadFactor = nil // 0 或负数表示清除
} else if *input.LoadFactor > 10000 {
return nil, errors.New("load_factor must be <= 10000")
} else {
account.LoadFactor = input.LoadFactor
}
}
if input.Status != "" {
account.Status = input.Status
}
@@ -1616,6 +1639,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.RateMultiplier != nil {
repoUpdates.RateMultiplier = input.RateMultiplier
}
if input.LoadFactor != nil {
if *input.LoadFactor <= 0 {
repoUpdates.LoadFactor = nil // 0 或负数表示清除
} else if *input.LoadFactor > 10000 {
return nil, errors.New("load_factor must be <= 10000")
} else {
repoUpdates.LoadFactor = input.LoadFactor
}
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
@@ -2439,3 +2471,7 @@ func (e *MixedChannelError) Error() string {
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
return s.accountRepo.ResetQuotaUsed(ctx, id)
}

View File

@@ -91,7 +91,7 @@ func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string)
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {

View File

@@ -97,3 +97,10 @@ func (k *APIKey) GetDaysUntilExpiry() int {
}
return int(duration.Hours() / 24)
}
// APIKeyListFilters holds optional filtering parameters for listing API keys.
type APIKeyListFilters struct {
Search string
Status string
GroupID *int64 // nil=不筛选, 0=无分组, >0=指定分组
}

View File

@@ -55,7 +55,7 @@ type APIKeyRepository interface {
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
@@ -392,8 +392,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
// List 获取用户的API Key列表
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
}

View File

@@ -53,7 +53,7 @@ func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}

View File

@@ -81,7 +81,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net/mail"
"strconv"
"strings"
"time"
@@ -33,6 +34,7 @@ var (
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
@@ -115,6 +117,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
if isReservedEmail(email) {
return "", nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return "", nil, err
}
// 检查是否需要邀请码
var invitationRedeemCode *RedeemCode
@@ -241,6 +246,9 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
if isReservedEmail(email) {
return ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return err
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
@@ -279,6 +287,9 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
if isReservedEmail(email) {
return nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return nil, err
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
@@ -624,6 +635,32 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
}
}
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil {
return nil
}
whitelist := s.settingService.GetRegistrationEmailSuffixWhitelist(ctx)
if !IsRegistrationEmailSuffixAllowed(email, whitelist) {
return buildEmailSuffixNotAllowedError(whitelist)
}
return nil
}
func buildEmailSuffixNotAllowedError(whitelist []string) error {
if len(whitelist) == 0 {
return ErrEmailSuffixNotAllowed
}
allowed := strings.Join(whitelist, ", ")
return infraerrors.BadRequest(
"EMAIL_SUFFIX_NOT_ALLOWED",
fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", allowed),
).WithMetadata(map[string]string{
"allowed_suffixes": strings.Join(whitelist, ","),
"allowed_suffix_count": strconv.Itoa(len(whitelist)),
})
}
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token降低 DoS 风险。

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
@@ -231,6 +232,51 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) {
require.ErrorIs(t, err, ErrEmailReserved)
}
func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
}, nil)
_, _, err := service.Register(context.Background(), "user@other.com", "password")
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
appErr := infraerrors.FromError(err)
require.Contains(t, appErr.Message, "@example.com")
require.Contains(t, appErr.Message, "@company.com")
require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason)
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"])
}
func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
repo := &userRepoStub{nextID: 8}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
}, nil)
_, user, err := service.Register(context.Background(), "user@example.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, int64(8), user.ID)
}
func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
}, nil)
err := service.SendVerifyCode(context.Background(), "user@other.com")
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
appErr := infraerrors.FromError(err)
require.Contains(t, appErr.Message, "@example.com")
require.Contains(t, appErr.Message, "@company.com")
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
}
func TestAuthService_Register_CreateError(t *testing.T) {
repo := &userRepoStub{createErr: errors.New("create failed")}
service := newAuthService(repo, map[string]string{
@@ -402,7 +448,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 42}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
}, nil)
service.defaultSubAssigner = assigner

View File

@@ -43,15 +43,24 @@ type BillingCache interface {
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致
type ModelPricing struct {
InputPricePerToken float64 // 每token输入价格 (USD)
OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
InputPricePerToken float64 // 每token输入价格 (USD)
OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
}
const (
openAIGPT54LongContextInputThreshold = 272000
openAIGPT54LongContextInputMultiplier = 2.0
openAIGPT54LongContextOutputMultiplier = 1.5
)
// UsageTokens 使用的token数量
type UsageTokens struct {
InputTokens int
@@ -161,6 +170,35 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
InputPricePerToken: 1.25e-6, // $1.25 per MTok
OutputPricePerToken: 10e-6, // $10 per MTok
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
CacheReadPricePerToken: 0.125e-6,
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.4(业务指定价格)
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
InputPricePerToken: 2.5e-6, // $2.5 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
SupportsCacheBreakdown: false,
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
}
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
InputPricePerToken: 1.5e-6, // $1.5 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
CacheReadPricePerToken: 0.15e-6,
SupportsCacheBreakdown: false,
}
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
}
// getFallbackPricing 根据模型系列获取回退价格
@@ -189,12 +227,30 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
}
return s.fallbackPrices["claude-3-haiku"]
}
// Claude 未知型号统一回退到 Sonnet避免计费中断。
if strings.Contains(modelLower, "claude") {
return s.fallbackPrices["claude-sonnet-4"]
}
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
return s.fallbackPrices["gemini-3.1-pro"]
}
// 默认使用Sonnet价格
return s.fallbackPrices["claude-sonnet-4"]
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
normalized := normalizeCodexModel(modelLower)
switch normalized {
case "gpt-5.4":
return s.fallbackPrices["gpt-5.4"]
case "gpt-5.3-codex":
return s.fallbackPrices["gpt-5.3-codex"]
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
return s.fallbackPrices["gpt-5.1-codex"]
case "gpt-5.1":
return s.fallbackPrices["gpt-5.1"]
}
}
return nil
}
// GetModelPricing 获取模型价格配置
@@ -212,15 +268,18 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
price5m := litellmPricing.CacheCreationInputTokenCost
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
enableBreakdown := price1h > 0 && price1h > price5m
return &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
}, nil
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
}), nil
}
}
@@ -228,7 +287,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
fallback := s.getFallbackPricing(model)
if fallback != nil {
log.Printf("[Billing] Using fallback pricing for model: %s", model)
return fallback, nil
return s.applyModelSpecificPricingPolicy(model, fallback), nil
}
return nil, fmt.Errorf("pricing not found for model: %s", model)
@@ -242,12 +301,18 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
}
breakdown := &CostBreakdown{}
inputPricePerToken := pricing.InputPricePerToken
outputPricePerToken := pricing.OutputPricePerToken
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
inputPricePerToken *= pricing.LongContextInputMultiplier
outputPricePerToken *= pricing.LongContextOutputMultiplier
}
// 计算输入token费用使用per-token价格
breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
// 计算输出token费用
breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
@@ -279,6 +344,45 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
return breakdown, nil
}
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
if pricing == nil {
return nil
}
if !isOpenAIGPT54Model(model) {
return pricing
}
if pricing.LongContextInputThreshold > 0 && pricing.LongContextInputMultiplier > 0 && pricing.LongContextOutputMultiplier > 0 {
return pricing
}
cloned := *pricing
if cloned.LongContextInputThreshold <= 0 {
cloned.LongContextInputThreshold = openAIGPT54LongContextInputThreshold
}
if cloned.LongContextInputMultiplier <= 0 {
cloned.LongContextInputMultiplier = openAIGPT54LongContextInputMultiplier
}
if cloned.LongContextOutputMultiplier <= 0 {
cloned.LongContextOutputMultiplier = openAIGPT54LongContextOutputMultiplier
}
return &cloned
}
func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens, pricing *ModelPricing) bool {
if pricing == nil || pricing.LongContextInputThreshold <= 0 {
return false
}
if pricing.LongContextInputMultiplier <= 1 && pricing.LongContextOutputMultiplier <= 1 {
return false
}
totalInputTokens := tokens.InputTokens + tokens.CacheReadTokens
return totalInputTokens > pricing.LongContextInputThreshold
}
func isOpenAIGPT54Model(model string) bool {
normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
return normalized == "gpt-5.4"
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
multiplier := s.cfg.Default.RateMultiplier

View File

@@ -133,7 +133,7 @@ func TestGetModelPricing_CaseInsensitive(t *testing.T) {
require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken)
}
func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
func TestGetModelPricing_UnknownClaudeModelFallsBackToSonnet(t *testing.T) {
svc := newTestBillingService()
// 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格
@@ -142,6 +142,93 @@ func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-unknown-model")
require.Error(t, err)
require.Nil(t, pricing)
require.Contains(t, err.Error(), "pricing not found")
}
func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-5.1")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-5.4")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
require.Equal(t, 272000, pricing.LongContextInputThreshold)
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
}
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{
InputTokens: 300000,
OutputTokens: 4000,
}
cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0)
require.NoError(t, err)
expectedInput := float64(tokens.InputTokens) * 2.5e-6 * 2.0
expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5
require.InDelta(t, expectedInput, cost.InputCost, 1e-10)
require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10)
require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10)
require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10)
}
func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
svc := newTestBillingService()
tests := []struct {
name string
model string
expectedInput float64
expectNilPricing bool
}{
{name: "empty model", model: " ", expectNilPricing: true},
{name: "claude opus 4.6", model: "claude-opus-4.6-20260201", expectedInput: 5e-6},
{name: "claude opus 4.5 alt separator", model: "claude-opus-4-5-20260101", expectedInput: 5e-6},
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
{name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
{name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
{name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pricing := svc.getFallbackPricing(tt.model)
if tt.expectNilPricing {
require.Nil(t, pricing)
return
}
require.NotNil(t, pricing)
require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12)
})
}
}
func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) {
svc := newTestBillingService()

View File

@@ -74,11 +74,12 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单JSON 数组)
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址

View File

@@ -88,6 +88,49 @@ func TestCheckErrorPolicy(t *testing.T) {
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled",
account: &Account{
ID: 14,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
account: &Account{
ID: 15,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyNone,
},
{
name: "temp_unschedulable_body_miss_returns_none",
account: &Account{

View File

@@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
require.NotNil(t, result)
require.True(t, result.Stream)
require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体")
require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射")
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
@@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
require.True(t, ok)
bodyBytes, ok := rawBody.([]byte)
require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
require.Equal(t, body, bodyBytes)
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型")
}
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
@@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体")
require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射")
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
@@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
require.Empty(t, rec.Header().Get("Set-Cookie"))
}
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
model string
modelMapping map[string]any // nil = 不配置映射
expectedModel string
endpoint string // "messages" or "count_tokens"
}{
{
name: "Forward: 无映射配置时不改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: nil,
expectedModel: "claude-sonnet-4-20250514",
endpoint: "messages",
},
{
name: "Forward: 空映射配置时不改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{},
expectedModel: "claude-sonnet-4-20250514",
endpoint: "messages",
},
{
name: "Forward: 模型不在映射表中时不改写",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
expectedModel: "claude-sonnet-4-20250514",
endpoint: "messages",
},
{
name: "Forward: 精确匹配映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "messages",
},
{
name: "Forward: 通配符映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "messages",
},
{
name: "CountTokens: 无映射配置时不改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: nil,
expectedModel: "claude-sonnet-4-20250514",
endpoint: "count_tokens",
},
{
name: "CountTokens: 模型不在映射表中时不改写",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
expectedModel: "claude-sonnet-4-20250514",
endpoint: "count_tokens",
},
{
name: "CountTokens: 精确匹配映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "count_tokens",
},
{
name: "CountTokens: 通配符映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "count_tokens",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
parsed := &ParsedRequest{
Body: body,
Model: tt.model,
}
credentials := map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
}
if tt.modelMapping != nil {
credentials["model_mapping"] = tt.modelMapping
}
account := &Account{
ID: 300,
Name: "edge-case-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: credentials,
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
if tt.endpoint == "messages" {
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
parsed.Stream = false
upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamJSON)),
},
}
svc := &GatewayService{
cfg: &config.Config{},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
result, err := svc.Forward(context.Background(), c, account, parsed)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
"Forward 上游请求体中的模型应为: %s", tt.expectedModel)
} else {
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
upstreamRespBody := `{"input_tokens":42}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
"CountTokens 上游请求体中的模型应为: %s", tt.expectedModel)
}
})
}
}
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields
// 确保模型映射只替换 model 字段,不影响请求体中的其他字段
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
// 包含复杂字段的请求体system、thinking、messages
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`)
parsed := &ParsedRequest{
Body: body,
Model: "claude-sonnet-4-20250514",
}
upstreamRespBody := `{"input_tokens":42}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 301,
Name: "preserve-fields-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
"model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
},
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
sentBody := upstream.lastBody
require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射")
require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改")
require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改")
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改")
require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改")
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
}
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
// 确保空模型名不会触发映射逻辑
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
parsed := &ParsedRequest{
Body: body,
Model: "", // 空模型
}
upstreamRespBody := `{"input_tokens":10}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 302,
Name: "empty-model-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
"model_mapping": map[string]any{"*": "claude-3-opus-20240229"},
},
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
// 空模型名时body 应原样透传,不应触发映射
require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改")
}
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -187,6 +187,14 @@ func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64
return 0, nil
}
func (m *mockAccountRepoForPlatform) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)

View File

@@ -1228,6 +1228,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
continue
}
// 配额检查
if !s.isAccountSchedulableForQuota(account) {
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
filteredWindowCost++
@@ -1260,6 +1264,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForQuota(stickyAccount) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
@@ -1311,7 +1316,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, acc := range routingCandidates {
routingLoads = append(routingLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
@@ -1416,6 +1421,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
s.isAccountSchedulableForQuota(account) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
@@ -1480,6 +1486,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
// 配额检查
if !s.isAccountSchedulableForQuota(acc) {
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
@@ -1499,7 +1509,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
@@ -2113,6 +2123,15 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
}
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内
// 仅适用于配置了 quota_limit 的 apikey 类型账号
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
if account.Type != AccountTypeAPIKey {
return true
}
return !account.IsQuotaExceeded()
}
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度false 表示不可调度
@@ -2590,7 +2609,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2644,6 +2663,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
@@ -2700,7 +2722,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
return account, nil
}
}
@@ -2743,6 +2765,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
@@ -2818,7 +2843,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
@@ -2874,6 +2899,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
@@ -2930,7 +2958,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
@@ -2975,6 +3003,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
@@ -3889,7 +3920,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime)
passthroughBody := parsed.Body
passthroughModel := parsed.Model
if passthroughModel != "" {
if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel {
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
passthroughModel = mappedModel
}
}
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
}
body := parsed.Body
@@ -4574,7 +4614,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
targetURL = validatedURL + "/v1/messages?beta=true"
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
@@ -4954,7 +4994,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
targetURL = validatedURL + "/v1/messages?beta=true"
}
}
@@ -6370,6 +6410,89 @@ type APIKeyQuotaUpdater interface {
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
}
// postUsageBillingParams 统一扣费所需的参数
type postUsageBillingParams struct {
Cost *CostBreakdown
User *User
APIKey *APIKey
Account *Account
Subscription *UserSubscription
IsSubscriptionBill bool
AccountRateMultiplier float64
APIKeyService APIKeyQuotaUpdater
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费
// - API Key 配额更新
// - API Key 限速用量更新
// - 账号配额用量更新账号口径TotalCost × 账号计费倍率)
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
cost := p.Cost
// 1. 订阅 / 余额扣费
if p.IsSubscriptionBill {
if cost.TotalCost > 0 {
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
}
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
}
}
// 2. API Key 配额
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 3. API Key 限速用量
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
}
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
}
// 4. 账号配额用量账号口径TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
}
}
// 5. 更新账号最近使用时间
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
}
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type billingDeps struct {
accountRepo AccountRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
deferredService *DeferredService
}
func (s *GatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
}
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
@@ -6533,45 +6656,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}
// 更新 API Key 配额(如果设置了配额限制)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err)
}
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
@@ -6731,44 +6830,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
// API Key 独立配额扣费
if input.APIKeyService != nil && apiKey.Quota > 0 {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err)
}
}
}
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
@@ -6781,7 +6857,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body)
passthroughBody := parsed.Body
if reqModel := parsed.Model; reqModel != "" {
if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel {
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
}
}
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
}
body := parsed.Body
@@ -7072,7 +7155,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
@@ -7119,7 +7202,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
}

View File

@@ -122,6 +122,28 @@ func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "gemini_apikey_temp_unschedulable_401_second_hit_returns_none",
account: &Account{
ID: 105,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyNone,
},
{
name: "gemini_custom_codes_override_temp_unschedulable",
account: &Account{

View File

@@ -176,6 +176,14 @@ func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64,
return 0, nil
}
func (m *mockAccountRepoForGemini) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (m *mockAccountRepoForGemini) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)

View File

@@ -19,8 +19,10 @@ import (
// 预编译正则表达式(避免每次调用重新编译)
var (
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
// 匹配 user_id 格式:
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
@@ -239,13 +241,16 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return body, nil
}
// 匹配格式: user_{64位hex}_account__session_{uuid}
// 匹配格式:
// 旧格式: user_{64位hex}_account__session_{uuid}
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
return body, nil
}
sessionTail := matches[1] // 原始session UUID
// matches[1] = account UUID (可能为空), matches[2] = session UUID
sessionTail := matches[2] // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)

View File

@@ -342,6 +342,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
if s.service.concurrencyService != nil {
return &AccountSelectionResult{
Account: account,
@@ -590,7 +591,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
filtered = append(filtered, account)
loadReq = append(loadReq, AccountWithConcurrency{
ID: account.ID,
MaxConcurrency: account.Concurrency,
MaxConcurrency: account.EffectiveLoadFactor(),
})
}
if len(filtered) == 0 {
@@ -703,6 +704,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
candidate := selectionOrder[0]
return &AccountSelectionResult{
Account: candidate.account,

View File

@@ -9,6 +9,13 @@ import (
var codexCLIInstructions string
var codexModelMap = map[string]string{
"gpt-5.4": "gpt-5.4",
"gpt-5.4-none": "gpt-5.4",
"gpt-5.4-low": "gpt-5.4",
"gpt-5.4-medium": "gpt-5.4",
"gpt-5.4-high": "gpt-5.4",
"gpt-5.4-xhigh": "gpt-5.4",
"gpt-5.4-chat-latest": "gpt-5.4",
"gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-none": "gpt-5.3-codex",
"gpt-5.3-low": "gpt-5.3-codex",
@@ -154,6 +161,9 @@ func normalizeCodexModel(model string) string {
normalized := strings.ToLower(modelID)
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
return "gpt-5.4"
}
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
return "gpt-5.2-codex"
}

View File

@@ -167,6 +167,10 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
cases := map[string]string{
"gpt-5.4": "gpt-5.4",
"gpt-5.4-high": "gpt-5.4",
"gpt-5.4-chat-latest": "gpt-5.4",
"gpt 5.4": "gpt-5.4",
"gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",

View File

@@ -263,13 +263,15 @@ type OpenAIGatewayService struct {
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
openaiSchedulerOnce sync.Once
openaiWSPool *openAIWSConnPool
openaiWSStateStore OpenAIWSStateStore
openaiScheduler OpenAIAccountScheduler
openaiAccountStats *openAIAccountRuntimeStats
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
openaiSchedulerOnce sync.Once
openaiWSPassthroughDialerOnce sync.Once
openaiWSPool *openAIWSConnPool
openaiWSStateStore OpenAIWSStateStore
openaiScheduler OpenAIAccountScheduler
openaiWSPassthroughDialer openAIWSClientDialer
openaiAccountStats *openAIAccountRuntimeStats
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiWSRetryMetrics openAIWSRetryMetrics
@@ -317,6 +319,16 @@ func NewOpenAIGatewayService(
return svc
}
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
}
}
// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
// 应在应用优雅关闭时调用。
func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
@@ -1240,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
@@ -3472,37 +3484,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
shouldBill := inserted || err != nil
// Deduct based on billing type
if isSubscriptionBilling {
if shouldBill && cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
if shouldBill && cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}
// Update API key quota if applicable (only for balance mode with quota set)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
}
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}

View File

@@ -11,6 +11,7 @@ import (
"sync/atomic"
"time"
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws "github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
conn *coderws.Conn
}
var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil)
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
}
}
func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.conn == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
msgType, payload, err := c.conn.Read(ctx)
if err != nil {
return coderws.MessageText, nil, err
}
return msgType, payload, nil
}
func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Write(ctx, msgType, payload)
}
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed

View File

@@ -46,9 +46,10 @@ const (
openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
openAIWSPayloadSizeEstimateMaxItems = 16
openAIWSEventFlushBatchSizeDefault = 4
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
openAIWSPayloadLogSampleDefault = 0.2
openAIWSEventFlushBatchSizeDefault = 4
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
openAIWSPayloadLogSampleDefault = 0.2
openAIWSPassthroughIdleTimeoutDefault = time.Hour
openAIWSStoreDisabledConnModeStrict = "strict"
openAIWSStoreDisabledConnModeAdaptive = "adaptive"
@@ -863,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool {
strings.Contains(message, "unexpected eof") ||
strings.Contains(message, "use of closed network connection") ||
strings.Contains(message, "connection reset by peer") ||
strings.Contains(message, "broken pipe")
strings.Contains(message, "broken pipe") ||
strings.Contains(message, "an established connection was aborted")
}
func classifyOpenAIWSReadFallbackReason(err error) string {
@@ -904,6 +906,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
return s.openaiWSPool
}
func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer {
if s == nil {
return nil
}
s.openaiWSPassthroughDialerOnce.Do(func() {
if s.openaiWSPassthroughDialer == nil {
s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer()
}
})
return s.openaiWSPassthroughDialer
}
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
pool := s.getOpenAIWSConnPool()
if pool == nil {
@@ -967,6 +981,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
return 15 * time.Minute
}
func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration {
if timeout := s.openAIWSReadTimeout(); timeout > 0 {
return timeout
}
return openAIWSPassthroughIdleTimeoutDefault
}
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
@@ -2322,7 +2343,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
ingressMode := OpenAIWSIngressModeShared
ingressMode := OpenAIWSIngressModeCtxPool
if modeRouterV2Enabled {
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
if ingressMode == OpenAIWSIngressModeOff {
@@ -2332,6 +2353,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
nil,
)
}
switch ingressMode {
case OpenAIWSIngressModePassthrough:
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
}
return s.proxyResponsesWebSocketV2Passthrough(
ctx,
c,
clientConn,
account,
token,
firstClientMessage,
hooks,
wsDecision,
)
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
// continue
default:
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"websocket mode only supports ctx_pool/passthrough",
nil,
)
}
}
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)

View File

@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
select {
case serverErr := <-serverErrCh:
@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
upstreamConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPassthroughDialer: captureDialer,
}
account := &Account{
ID: 452,
Name: "openai-ingress-passthrough",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
},
}
serverErrCh := make(chan error, 1)
resultCh := make(chan *OpenAIForwardResult, 1)
hooks := &OpenAIWSIngressHooks{
AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
if turnErr == nil && result != nil {
resultCh <- result
}
},
}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String())
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 passthrough websocket 结束超时")
}
select {
case result := <-resultCh:
require.Equal(t, "resp_passthrough_turn_1", result.RequestID)
require.True(t, result.OpenAIWSMode)
require.Equal(t, 2, result.Usage.InputTokens)
require.Equal(t, 3, result.Usage.OutputTokens)
case <-time.After(2 * time.Second):
t.Fatal("未收到 passthrough turn 结果回调")
}
require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket")
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
return event, nil
}
func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
payload, err := c.ReadMessage(ctx)
if err != nil {
return coderws.MessageText, nil, err
}
return coderws.MessageText, payload, nil
}
func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error {
return c.WriteJSON(ctx, json.RawMessage(payload))
}
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
_ = ctx
return nil

View File

@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
switch mode {
case OpenAIWSIngressModeOff:
return openAIWSHTTPDecision("account_mode_off")
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough:
// continue
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
// 历史值兼容:按 ctx_pool 处理。
mode = OpenAIWSIngressModeCtxPool
default:
return openAIWSHTTPDecision("account_mode_off")
}

View File

@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
},
}
t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("off mode routes to http", func(t *testing.T) {
@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
require.Equal(t, "account_mode_off", decision.Reason)
})
t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
legacyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_shared", decision.Reason)
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("passthrough mode routes to ws v2", func(t *testing.T) {
passthroughAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_passthrough", decision.Reason)
})
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)

View File

@@ -0,0 +1,24 @@
package openai_ws_v2
import (
"context"
)
// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想:
// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。
//
// Reference:
// - Project: caddyserver/caddy (Apache-2.0)
// - Commit: f283062d37c50627d53ca682ebae2ce219b35515
// - Files:
// - modules/caddyhttp/reverseproxy/streaming.go
// - modules/caddyhttp/reverseproxy/reverseproxy.go
func runCaddyStyleRelay(
ctx context.Context,
clientConn FrameConn,
upstreamConn FrameConn,
firstClientMessage []byte,
options RelayOptions,
) (RelayResult, *RelayExit) {
return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
}

View File

@@ -0,0 +1,23 @@
package openai_ws_v2
import "context"
// EntryInput 是 passthrough v2 数据面的入口参数。
type EntryInput struct {
Ctx context.Context
ClientConn FrameConn
UpstreamConn FrameConn
FirstClientMessage []byte
Options RelayOptions
}
// RunEntry 是 openai_ws_v2 包对外的统一入口。
func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
return runCaddyStyleRelay(
input.Ctx,
input.ClientConn,
input.UpstreamConn,
input.FirstClientMessage,
input.Options,
)
}

View File

@@ -0,0 +1,29 @@
package openai_ws_v2
import (
"sync/atomic"
)
// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。
type MetricsSnapshot struct {
SemanticMutationTotal int64 `json:"semantic_mutation_total"`
UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
}
var (
// passthrough 路径默认不会做语义改写,该计数通常应保持为 0保留用于未来防御性校验
passthroughSemanticMutationTotal atomic.Int64
passthroughUsageParseFailureTotal atomic.Int64
)
func recordUsageParseFailure() {
passthroughUsageParseFailureTotal.Add(1)
}
// SnapshotMetrics 返回当前 passthrough 指标快照。
func SnapshotMetrics() MetricsSnapshot {
return MetricsSnapshot{
SemanticMutationTotal: passthroughSemanticMutationTotal.Load(),
UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(),
}
}

View File

@@ -0,0 +1,807 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
coderws "github.com/coder/websocket"
"github.com/tidwall/gjson"
)
type FrameConn interface {
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
Close() error
}
type Usage struct {
InputTokens int
OutputTokens int
CacheCreationInputTokens int
CacheReadInputTokens int
}
type RelayResult struct {
RequestModel string
Usage Usage
RequestID string
TerminalEventType string
FirstTokenMs *int
Duration time.Duration
ClientToUpstreamFrames int64
UpstreamToClientFrames int64
DroppedDownstreamFrames int64
}
type RelayTurnResult struct {
RequestModel string
Usage Usage
RequestID string
TerminalEventType string
Duration time.Duration
FirstTokenMs *int
}
type RelayExit struct {
Stage string
Err error
WroteDownstream bool
}
type RelayOptions struct {
WriteTimeout time.Duration
IdleTimeout time.Duration
UpstreamDrainTimeout time.Duration
FirstMessageType coderws.MessageType
OnUsageParseFailure func(eventType string, usageRaw string)
OnTurnComplete func(turn RelayTurnResult)
OnTrace func(event RelayTraceEvent)
Now func() time.Time
}
type RelayTraceEvent struct {
Stage string
Direction string
MessageType string
PayloadBytes int
Graceful bool
WroteDownstream bool
Error string
}
type relayState struct {
usage Usage
requestModel string
lastResponseID string
terminalEventType string
firstTokenMs *int
turnTimingByID map[string]*relayTurnTiming
}
type relayExitSignal struct {
stage string
err error
graceful bool
wroteDownstream bool
}
type observedUpstreamEvent struct {
terminal bool
eventType string
responseID string
usage Usage
duration time.Duration
firstToken *int
}
type relayTurnTiming struct {
startAt time.Time
firstTokenMs *int
}
func Relay(
ctx context.Context,
clientConn FrameConn,
upstreamConn FrameConn,
firstClientMessage []byte,
options RelayOptions,
) (RelayResult, *RelayExit) {
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
if clientConn == nil || upstreamConn == nil {
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
}
if ctx == nil {
ctx = context.Background()
}
nowFn := options.Now
if nowFn == nil {
nowFn = time.Now
}
writeTimeout := options.WriteTimeout
if writeTimeout <= 0 {
writeTimeout = 2 * time.Minute
}
drainTimeout := options.UpstreamDrainTimeout
if drainTimeout <= 0 {
drainTimeout = 1200 * time.Millisecond
}
firstMessageType := options.FirstMessageType
if firstMessageType != coderws.MessageBinary {
firstMessageType = coderws.MessageText
}
startAt := nowFn()
state := &relayState{requestModel: result.RequestModel}
onTrace := options.OnTrace
relayCtx, relayCancel := context.WithCancel(ctx)
defer relayCancel()
lastActivity := atomic.Int64{}
lastActivity.Store(nowFn().UnixNano())
markActivity := func() {
lastActivity.Store(nowFn().UnixNano())
}
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
defer cancel()
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
}
writeClient := func(msgType coderws.MessageType, payload []byte) error {
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
defer cancel()
return clientConn.WriteFrame(writeCtx, msgType, payload)
}
clientToUpstreamFrames := &atomic.Int64{}
upstreamToClientFrames := &atomic.Int64{}
droppedDownstreamFrames := &atomic.Int64{}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_start",
PayloadBytes: len(firstClientMessage),
MessageType: relayMessageTypeString(firstMessageType),
})
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
result.Duration = nowFn().Sub(startAt)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_failed",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
Error: err.Error(),
})
return result, &RelayExit{Stage: "write_upstream", Err: err}
}
clientToUpstreamFrames.Add(1)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_ok",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
})
markActivity()
exitCh := make(chan relayExitSignal, 3)
dropDownstreamWrites := atomic.Bool{}
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
go runUpstreamToClient(
relayCtx,
upstreamConn,
writeClient,
startAt,
nowFn,
state,
options.OnUsageParseFailure,
options.OnTurnComplete,
&dropDownstreamWrites,
upstreamToClientFrames,
droppedDownstreamFrames,
markActivity,
onTrace,
exitCh,
)
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
firstExit := <-exitCh
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "first_exit",
Direction: relayDirectionFromStage(firstExit.stage),
Graceful: firstExit.graceful,
WroteDownstream: firstExit.wroteDownstream,
Error: relayErrorString(firstExit.err),
})
combinedWroteDownstream := firstExit.wroteDownstream
secondExit := relayExitSignal{graceful: true}
hasSecondExit := false
// 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。
if firstExit.stage == "read_client" && firstExit.graceful {
dropDownstreamWrites.Store(true)
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
} else {
relayCancel()
_ = upstreamConn.Close()
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
}
if hasSecondExit {
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "second_exit",
Direction: relayDirectionFromStage(secondExit.stage),
Graceful: secondExit.graceful,
WroteDownstream: secondExit.wroteDownstream,
Error: relayErrorString(secondExit.err),
})
}
relayCancel()
_ = upstreamConn.Close()
enrichResult(&result, state, nowFn().Sub(startAt))
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
if firstExit.stage == "read_client" && firstExit.graceful {
stage := "client_disconnected"
exitErr := firstExit.err
if hasSecondExit && !secondExit.graceful {
stage = secondExit.stage
exitErr = secondExit.err
}
if exitErr == nil {
exitErr = io.EOF
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(exitErr),
})
return result, &RelayExit{
Stage: stage,
Err: exitErr,
WroteDownstream: combinedWroteDownstream,
}
}
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
_ = clientConn.Close()
return result, nil
}
if !firstExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(firstExit.stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(firstExit.err),
})
return result, &RelayExit{
Stage: firstExit.stage,
Err: firstExit.err,
WroteDownstream: combinedWroteDownstream,
}
}
if hasSecondExit && !secondExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(secondExit.stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(secondExit.err),
})
return result, &RelayExit{
Stage: secondExit.stage,
Err: secondExit.err,
WroteDownstream: combinedWroteDownstream,
}
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
_ = clientConn.Close()
return result, nil
}
func runClientToUpstream(
ctx context.Context,
clientConn FrameConn,
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
markActivity func(),
forwardedFrames *atomic.Int64,
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
for {
msgType, payload, err := clientConn.ReadFrame(ctx)
if err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "read_client_failed",
Direction: "client_to_upstream",
Error: err.Error(),
Graceful: isDisconnectError(err),
})
exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)}
return
}
markActivity()
if err := writeUpstream(msgType, payload); err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_upstream_failed",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
Error: err.Error(),
})
exitCh <- relayExitSignal{stage: "write_upstream", err: err}
return
}
if forwardedFrames != nil {
forwardedFrames.Add(1)
}
markActivity()
}
}
func runUpstreamToClient(
ctx context.Context,
upstreamConn FrameConn,
writeClient func(msgType coderws.MessageType, payload []byte) error,
startAt time.Time,
nowFn func() time.Time,
state *relayState,
onUsageParseFailure func(eventType string, usageRaw string),
onTurnComplete func(turn RelayTurnResult),
dropDownstreamWrites *atomic.Bool,
forwardedFrames *atomic.Int64,
droppedFrames *atomic.Int64,
markActivity func(),
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
wroteDownstream := false
for {
msgType, payload, err := upstreamConn.ReadFrame(ctx)
if err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "read_upstream_failed",
Direction: "upstream_to_client",
Error: err.Error(),
Graceful: isDisconnectError(err),
WroteDownstream: wroteDownstream,
})
exitCh <- relayExitSignal{
stage: "read_upstream",
err: err,
graceful: isDisconnectError(err),
wroteDownstream: wroteDownstream,
}
return
}
markActivity()
observedEvent := observedUpstreamEvent{}
switch msgType {
case coderws.MessageText:
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure)
case coderws.MessageBinary:
// binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。
}
emitTurnComplete(onTurnComplete, state, observedEvent)
if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
if droppedFrames != nil {
droppedFrames.Add(1)
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "drop_downstream_frame",
Direction: "upstream_to_client",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
WroteDownstream: wroteDownstream,
})
if observedEvent.terminal {
exitCh <- relayExitSignal{
stage: "drain_terminal",
graceful: true,
wroteDownstream: wroteDownstream,
}
return
}
markActivity()
continue
}
if err := writeClient(msgType, payload); err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_client_failed",
Direction: "upstream_to_client",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
WroteDownstream: wroteDownstream,
Error: err.Error(),
})
exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
return
}
wroteDownstream = true
if forwardedFrames != nil {
forwardedFrames.Add(1)
}
markActivity()
}
}
func runIdleWatchdog(
ctx context.Context,
nowFn func() time.Time,
idleTimeout time.Duration,
lastActivity *atomic.Int64,
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
if idleTimeout <= 0 {
return
}
checkInterval := minDuration(idleTimeout/4, 5*time.Second)
if checkInterval < time.Second {
checkInterval = time.Second
}
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
last := time.Unix(0, lastActivity.Load())
if nowFn().Sub(last) < idleTimeout {
continue
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "idle_timeout_triggered",
Direction: "watchdog",
Error: context.DeadlineExceeded.Error(),
})
exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
return
}
}
}
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
if onTrace == nil {
return
}
onTrace(event)
}
func relayMessageTypeString(msgType coderws.MessageType) string {
switch msgType {
case coderws.MessageText:
return "text"
case coderws.MessageBinary:
return "binary"
default:
return "unknown(" + strconv.Itoa(int(msgType)) + ")"
}
}
func relayDirectionFromStage(stage string) string {
switch stage {
case "read_client", "write_upstream":
return "client_to_upstream"
case "read_upstream", "write_client", "drain_terminal":
return "upstream_to_client"
case "idle_timeout":
return "watchdog"
default:
return ""
}
}
func relayErrorString(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func observeUpstreamMessage(
state *relayState,
message []byte,
startAt time.Time,
nowFn func() time.Time,
onUsageParseFailure func(eventType string, usageRaw string),
) observedUpstreamEvent {
if state == nil || len(message) == 0 {
return observedUpstreamEvent{}
}
values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
eventType := strings.TrimSpace(values[0].String())
if eventType == "" {
return observedUpstreamEvent{}
}
responseID := strings.TrimSpace(values[1].String())
if responseID == "" {
responseID = strings.TrimSpace(values[2].String())
}
// 仅 terminal 事件兜底读取顶层 id避免把 event_id 当成 response_id 关联到 turn。
if responseID == "" && isTerminalEvent(eventType) {
responseID = strings.TrimSpace(values[3].String())
}
now := nowFn()
if state.firstTokenMs == nil && isTokenEvent(eventType) {
ms := int(now.Sub(startAt).Milliseconds())
if ms >= 0 {
state.firstTokenMs = &ms
}
}
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
observed := observedUpstreamEvent{
eventType: eventType,
responseID: responseID,
usage: parsedUsage,
}
if responseID != "" {
turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
ms := int(now.Sub(turnTiming.startAt).Milliseconds())
if ms >= 0 {
turnTiming.firstTokenMs = &ms
}
}
}
if !isTerminalEvent(eventType) {
return observed
}
observed.terminal = true
state.terminalEventType = eventType
if responseID != "" {
state.lastResponseID = responseID
if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
duration := now.Sub(turnTiming.startAt)
if duration < 0 {
duration = 0
}
observed.duration = duration
observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
}
}
return observed
}
func emitTurnComplete(
onTurnComplete func(turn RelayTurnResult),
state *relayState,
observed observedUpstreamEvent,
) {
if onTurnComplete == nil || !observed.terminal {
return
}
responseID := strings.TrimSpace(observed.responseID)
if responseID == "" {
return
}
requestModel := ""
if state != nil {
requestModel = state.requestModel
}
onTurnComplete(RelayTurnResult{
RequestModel: requestModel,
Usage: observed.usage,
RequestID: responseID,
TerminalEventType: observed.eventType,
Duration: observed.duration,
FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
})
}
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
if state == nil {
return nil
}
if state.turnTimingByID == nil {
state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
}
timing, ok := state.turnTimingByID[responseID]
if !ok || timing == nil || timing.startAt.IsZero() {
timing = &relayTurnTiming{startAt: now}
state.turnTimingByID[responseID] = timing
return timing
}
return timing
}
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
if state == nil || state.turnTimingByID == nil {
return relayTurnTiming{}, false
}
timing, ok := state.turnTimingByID[responseID]
if !ok || timing == nil {
return relayTurnTiming{}, false
}
delete(state.turnTimingByID, responseID)
return *timing, true
}
func openAIWSRelayCloneIntPtr(v *int) *int {
if v == nil {
return nil
}
cloned := *v
return &cloned
}
func parseUsageAndAccumulate(
state *relayState,
message []byte,
eventType string,
onParseFailure func(eventType string, usageRaw string),
) Usage {
if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
return Usage{}
}
usageResult := gjson.GetBytes(message, "response.usage")
if !usageResult.Exists() {
return Usage{}
}
usageRaw := strings.TrimSpace(usageResult.Raw)
if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
recordUsageParseFailure()
if onParseFailure != nil {
onParseFailure(eventType, usageRaw)
}
return Usage{}
}
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
inputTokens, inputOK := parseUsageIntField(inputResult, true)
outputTokens, outputOK := parseUsageIntField(outputResult, true)
cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
if !inputOK || !outputOK || !cachedOK {
recordUsageParseFailure()
if onParseFailure != nil {
onParseFailure(eventType, usageRaw)
}
// 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。
return Usage{}
}
parsedUsage := Usage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheReadInputTokens: cachedTokens,
}
state.usage.InputTokens += parsedUsage.InputTokens
state.usage.OutputTokens += parsedUsage.OutputTokens
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
return parsedUsage
}
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
if !value.Exists() {
return 0, !required
}
if value.Type != gjson.Number {
return 0, false
}
return int(value.Int()), true
}
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
if result == nil {
return
}
result.Duration = duration
if state == nil {
return
}
result.RequestModel = state.requestModel
result.Usage = state.usage
result.RequestID = state.lastResponseID
result.TerminalEventType = state.terminalEventType
result.FirstTokenMs = state.firstTokenMs
}
func isDisconnectError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
return true
}
switch coderws.CloseStatus(err) {
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
return true
}
message := strings.ToLower(strings.TrimSpace(err.Error()))
if message == "" {
return false
}
return strings.Contains(message, "failed to read frame header: eof") ||
strings.Contains(message, "unexpected eof") ||
strings.Contains(message, "use of closed network connection") ||
strings.Contains(message, "connection reset by peer") ||
strings.Contains(message, "broken pipe")
}
func isTerminalEvent(eventType string) bool {
switch eventType {
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false
}
}
func shouldParseUsage(eventType string) bool {
switch eventType {
case "response.completed", "response.done", "response.failed":
return true
default:
return false
}
}
func isTokenEvent(eventType string) bool {
if eventType == "" {
return false
}
switch eventType {
case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
return false
}
if strings.Contains(eventType, ".delta") {
return true
}
if strings.HasPrefix(eventType, "response.output_text") {
return true
}
if strings.HasPrefix(eventType, "response.output") {
return true
}
return eventType == "response.completed" || eventType == "response.done"
}
func minDuration(a, b time.Duration) time.Duration {
if a <= 0 {
return b
}
if b <= 0 {
return a
}
if a < b {
return a
}
return b
}
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
if timeout <= 0 {
timeout = 200 * time.Millisecond
}
select {
case sig := <-exitCh:
return sig, true
case <-time.After(timeout):
return relayExitSignal{}, false
}
}

View File

@@ -0,0 +1,432 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestRunEntry_DelegatesRelay(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true)
result, relayExit := RunEntry(EntryInput{
Ctx: context.Background(),
ClientConn: clientConn,
UpstreamConn: upstreamConn,
FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
})
require.Nil(t, relayExit)
require.Equal(t, "resp_entry", result.RequestID)
}
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
t.Parallel()
t.Run("read client eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.True(t, sig.graceful)
})
t.Run("write upstream failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_upstream", sig.stage)
require.False(t, sig.graceful)
})
t.Run("forwarded counter and trace callback", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
forwarded := &atomic.Int64{}
traces := make([]RelayTraceEvent, 0, 2)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
forwarded,
func(event RelayTraceEvent) {
traces = append(traces, event)
},
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.Equal(t, int64(1), forwarded.Load())
require.NotEmpty(t, traces)
})
}
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
t.Parallel()
t.Run("read upstream eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_upstream", sig.stage)
require.True(t, sig.graceful)
})
t.Run("write client failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_client", sig.stage)
})
t.Run("drop downstream and stop on terminal", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(true)
dropped := &atomic.Int64{}
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
dropped,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "drain_terminal", sig.stage)
require.True(t, sig.graceful)
require.Equal(t, int64(1), dropped.Load())
})
}
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
lastActivity := &atomic.Int64{}
lastActivity.Store(time.Now().UnixNano())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
select {
case <-exitCh:
t.Fatal("unexpected idle timeout signal")
case <-time.After(200 * time.Millisecond):
}
}
func TestHelperFunctionsCoverage(t *testing.T) {
t.Parallel()
require.Equal(t, "text", relayMessageTypeString(coderws.MessageText))
require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary))
require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(")
require.Equal(t, "", relayErrorString(nil))
require.Equal(t, "x", relayErrorString(errors.New("x")))
require.True(t, isDisconnectError(io.EOF))
require.True(t, isDisconnectError(net.ErrClosed))
require.True(t, isDisconnectError(context.Canceled))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway}))
require.True(t, isDisconnectError(errors.New("broken pipe")))
require.False(t, isDisconnectError(errors.New("unrelated")))
require.True(t, isTokenEvent("response.output_text.delta"))
require.True(t, isTokenEvent("response.output_audio.delta"))
require.True(t, isTokenEvent("response.completed"))
require.False(t, isTokenEvent(""))
require.False(t, isTokenEvent("response.created"))
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second))
require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0))
ch := make(chan relayExitSignal, 1)
ch <- relayExitSignal{stage: "ok"}
sig, ok := waitRelayExit(ch, 10*time.Millisecond)
require.True(t, ok)
require.Equal(t, "ok", sig.stage)
ch <- relayExitSignal{stage: "ok2"}
sig, ok = waitRelayExit(ch, 0)
require.True(t, ok)
require.Equal(t, "ok2", sig.stage)
_, ok = waitRelayExit(ch, 10*time.Millisecond)
require.False(t, ok)
n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true)
require.True(t, ok)
require.Equal(t, 3, n)
_, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true)
require.False(t, ok)
n, ok = parseUsageIntField(gjson.Result{}, false)
require.True(t, ok)
require.Equal(t, 0, n)
_, ok = parseUsageIntField(gjson.Result{}, true)
require.False(t, ok)
}
func TestParseUsageAndEnrichCoverage(t *testing.T) {
t.Parallel()
state := &relayState{}
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil)
require.Equal(t, 0, state.usage.InputTokens)
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
require.Equal(t, 2, state.usage.InputTokens)
require.Equal(t, 1, state.usage.OutputTokens)
require.Equal(t, 1, state.usage.CacheReadInputTokens)
result := &RelayResult{}
enrichResult(result, state, 5*time.Millisecond)
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
require.Equal(t, 5*time.Millisecond, result.Duration)
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
require.Equal(t, 2, state.usage.InputTokens)
enrichResult(nil, state, 0)
}
func TestEmitTurnCompleteCoverage(t *testing.T) {
t.Parallel()
// 非 terminal 事件不应触发。
called := 0
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: false,
eventType: "response.output_text.delta",
responseID: "resp_ignored",
usage: Usage{InputTokens: 1},
})
require.Equal(t, 0, called)
// 缺少 response_id 时不应触发。
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
})
require.Equal(t, 0, called)
// terminal 且 response_id 存在应该触发state=nil 时 model 为空串。
var got RelayTurnResult
emitTurnComplete(func(turn RelayTurnResult) {
called++
got = turn
}, nil, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
responseID: "resp_emit",
usage: Usage{InputTokens: 2, OutputTokens: 3},
})
require.Equal(t, 1, called)
require.Equal(t, "resp_emit", got.RequestID)
require.Equal(t, "response.completed", got.TerminalEventType)
require.Equal(t, 2, got.Usage.InputTokens)
require.Equal(t, 3, got.Usage.OutputTokens)
require.Equal(t, "", got.RequestModel)
}
func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) {
t.Parallel()
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure}))
require.True(t, isDisconnectError(errors.New("connection reset by peer")))
require.False(t, isDisconnectError(errors.New(" ")))
}
func TestIsTokenEventCoverageBranches(t *testing.T) {
t.Parallel()
require.False(t, isTokenEvent("response.in_progress"))
require.False(t, isTokenEvent("response.output_item.added"))
require.True(t, isTokenEvent("response.output_audio.delta"))
require.True(t, isTokenEvent("response.output"))
require.True(t, isTokenEvent("response.done"))
}
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
t.Parallel()
now := time.Unix(100, 0)
// nil state
require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now))
_, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
require.False(t, ok)
state := &relayState{}
timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now)
require.NotNil(t, timing)
require.Equal(t, now, timing.startAt)
// 再次获取返回同一条 timing
timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second))
require.NotNil(t, timing2)
require.Equal(t, now, timing2.startAt)
// 删除存在键
deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a")
require.True(t, ok)
require.Equal(t, now, deleted.startAt)
// 删除不存在键
_, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a")
require.False(t, ok)
}
func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) {
t.Parallel()
state := &relayState{requestModel: "gpt-5"}
startAt := time.Unix(0, 0)
now := startAt
nowFn := func() time.Time {
now = now.Add(5 * time.Millisecond)
return now
}
// 非 terminal仅有顶层 id不应把 event id 当成 response_id。
observed := observeUpstreamMessage(
state,
[]byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`),
startAt,
nowFn,
nil,
)
require.False(t, observed.terminal)
require.Equal(t, "", observed.responseID)
// terminal允许兜底用顶层 id用于兼容少数字段变体
observed = observeUpstreamMessage(
state,
[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`),
startAt,
nowFn,
nil,
)
require.True(t, observed.terminal)
require.Equal(t, "resp_fallback", observed.responseID)
}

Some files were not shown because too many files have changed in this diff Show More