Compare commits


49 Commits

Author SHA1 Message Date
shaw
bab4bb9904 chore: update the OpenAI and Claude API key tutorial sections 2026-03-05 18:58:10 +08:00
shaw
33bae6f49b fix: split Cache Token into cache-creation and cache-read 2026-03-05 18:32:17 +08:00
Wesley Liddick
32d619a56b Merge pull request #780 from mt21625457/feat/codex-remote-compact-outcome-logging
feat(openai-handler): support codex remote compact outcome logging
2026-03-05 16:59:02 +08:00
Wesley Liddick
642432cf2a Merge pull request #777 from guoyongchang/feature-schedule-test-support
feat: support crontab-based scheduled account testing
2026-03-05 16:57:23 +08:00
程序猿MT
61e9598b08 fix(lint): remove redundant context type in compact outcome logger 2026-03-05 16:51:46 +08:00
guoyongchang
d4e34c7514 fix: prevent the scheduled-test modal from crashing on empty results
When the backend returns null (a Go nil slice), the frontend's .length access throws a TypeError;
add `?? []` fallbacks for listByAccount and listResults at the API layer.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:47:01 +08:00
程序猿MT
bfe7a5e452 test(openai-handler): add codex remote compact outcome coverage 2026-03-05 16:46:14 +08:00
程序猿MT
77d916ffec feat(openai-handler): support codex remote compact outcome logging 2026-03-05 16:46:12 +08:00
guoyongchang
831abf7977 refactor: remove a redundant intermediate type and unnecessary code
- Remove the ScheduledTestOutcome intermediate type; RunTestBackground now returns *ScheduledTestResult directly
- Simplify SaveResult to accept *ScheduledTestResult directly
- Remove an unnecessary nil check in the handler
- Remove a redundant String() conversion in the frontend ScheduledTestsPanel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:37:07 +08:00
guoyongchang
817a491087 simplify: remove the leader lock; a single instance needs no distributed lock
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:31:27 +08:00
guoyongchang
9a8dacc514 fix: fix golangci-lint depguard and gofmt errors
Extract the Redis leader-lock logic from the service layer into a LeaderLocker interface and
move the implementation to the repository layer, removing the service layer's direct dependency on Redis.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:28:48 +08:00
guoyongchang
8adf80d98b fix: add the missing scheduledTestRunner argument in wire_gen_test
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:23:41 +08:00
guoyongchang
62686a6213 revert: revert the local-testing changes to docker-compose.local.yml
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:17:33 +08:00
guoyongchang
3a089242f8 feat: support crontab-based scheduled account testing
Each test plan binds one account and one model, runs periodically per its cron expression,
stores historical results, and offers full CRUD plus result viewing on the frontend account-management page.

Main changes:
- Add the scheduled_test_plans / scheduled_test_results tables and their migrations
- Backend service layer: CRUD service plus a background cron runner (scans due plans every minute and runs them concurrently)
- The RunTestBackground method runs the account test in memory via httptest and parses the SSE output
- Redis leader lock plus pg_try_advisory_lock as a double guard so multi-instance deployments execute each plan only once
- REST API: 5 admin endpoints (plan CRUD plus result queries)
- Frontend ScheduledTestsPanel component: plan management, enable toggle, inline editing, expandable result view
- Chinese and English i18n support

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:06:05 +08:00
shaw
9d70c38504 fix: fix a bug where Claude API key account requests did not carry the beta=true query parameter 2026-03-05 15:01:04 +08:00
shaw
aeb464f3ca feat: apply model mapping to the /v1/messages/count_tokens endpoint 2026-03-05 14:49:28 +08:00
Wesley Liddick
7076717b20 Merge pull request #772 from mt21625457/aicodex2api-main
feat(openai-ws): merge the WS v2 passthrough mode with the frontend ws mode
2026-03-05 13:46:02 +08:00
程序猿MT
c0a4fcea0a Delete docker-compose-aicodex.yml
Delete the test docker compose file
2026-03-05 13:44:07 +08:00
程序猿MT
aa2b195c86 Delete Caddyfile.dmit
Delete the test Caddy config file
2026-03-05 13:43:25 +08:00
yangjianbo
1d0872e7ca feat(openai-ws): merge the WS v2 passthrough mode with the frontend ws mode
Add the OpenAI WebSocket v2 passthrough relay data plane and a service adaptation layer,
routing between ctx_pool and passthrough based on each account's ws mode.

Also change the frontend OpenAI ws mode options to off/ctx_pool/passthrough,
with matching i18n copy and unit tests.

Add the Caddyfile.dmit and docker-compose-aicodex.yml deployment configs
for reverse proxying and service orchestration on a host machine.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 11:50:58 +08:00
shaw
33988637b5 fix: the SMTP connection test and test-email sending now return the specific error instead of "internal error" 2026-03-05 10:54:41 +08:00
shaw
d4f6ad7225 feat: add a usage query page for API keys 2026-03-05 10:45:51 +08:00
shaw
078fefed03 fix: fix a bug where the capacity column on the account-management page displayed 0 2026-03-05 09:48:00 +08:00
Wesley Liddick
5b10af85b4 Merge pull request #762 from touwaeriol/fix/dark-theme-open-in-new-tab
fix: add dark theme support for "open in new tab" FAB button
2026-03-05 08:56:28 +08:00
Wesley Liddick
4caf95e5dd Merge pull request #767 from litianc/fix/rewrite-userid-regex-match-account-uuid
fix: extend RewriteUserID regex to match user_id containing account_uuid
2026-03-05 08:56:03 +08:00
litianc
8e1bcf53bb fix: extend RewriteUserID regex to match user_id containing account_uuid
The existing regex only matched the old format where account_uuid is
empty (account__session_). Real Claude Code clients and newer sub2api
generated user_ids use account_{uuid}_session_ which was silently
skipped, causing the original metadata.user_id to leak to upstream
when User-Agent is rewritten by an intermediate gateway.

Closes #766

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-04 23:13:17 +08:00
erio
064f9be7e4 fix: add dark theme support for "open in new tab" FAB button
The backdrop-blur background on the iframe "open in new tab" floating
button was hardcoded to bg-white/80, making it look broken in dark
theme. Added dark:bg-dark-800/80 variant for both PurchaseSubscription
and CustomPage views.
2026-03-04 21:40:40 +08:00
Wesley Liddick
adcfb44cb7 Merge pull request #761 from james-6-23/main
feat: fix the v0.1.89 OAuth 401 permanent account lockout by using temporary unschedulable status with automatic recovery; escalate a second 401 to error status, with a DB fallback to make it stick; add a temporarily-unschedulable status filter to the admin console
2026-03-04 21:11:24 +08:00
kyx236
3d79773ba2 Merge branch 'main' of https://github.com/james-6-23/sub2api 2026-03-04 20:25:39 +08:00
kyx236
6aa8cbbf20 feat: escalate a second 401 directly to error status, with a DB fallback to make it stick
A first 401 only marks the account temporarily unschedulable, leaving a window for a token refresh;
a second 401 after recovery means the credentials really are invalid, so escalate to error status
to avoid repeated futile scheduling.

- Fall back to reading the reason from the DB when it is empty in the cache, so the escalation check cannot silently fail
- ClearError also clears the temporarily-unschedulable status, giving the account one more chance after an admin restores it
- Add a "temporarily unschedulable" status filter to the admin account list
- Add unit tests for the DB-fallback scenario
2026-03-04 20:25:15 +08:00
shaw
742e73c9c2 fix: polish the top-up/subscription menu icons 2026-03-04 17:24:09 +08:00
shaw
f8de2bdedc fix(frontend): split the settings page into tabs 2026-03-04 16:59:57 +08:00
shaw
59879b7fa7 fix(i18n): replace hardcoded English strings in EmailVerifyView with i18n calls 2026-03-04 15:58:44 +08:00
Wesley Liddick
27abae21b8 Merge pull request #724 from PMExtra/feat/registration-email-domain-whitelist
feat(registration): add email domain whitelist policy
2026-03-04 15:51:51 +08:00
shaw
0819c8a51a refactor: deduplicate normalizeAccountIDList and add unit tests for the components introduced in PR #754
- Delete the duplicate normalizeAccountIDList in account_today_stats_cache.go; use normalizeInt64IDList from id_list_utils.go everywhere
- Add snapshot_cache_test.go covering snapshotCache, buildETagFromAny, and parseBoolQueryWithDefault
- Add id_list_utils_test.go covering normalizeInt64IDList and buildAccountTodayStatsBatchCacheKey
- Add ops_query_mode_test.go covering shouldFallbackOpsPreagg and cloneOpsFilterWithMode
2026-03-04 15:22:46 +08:00
Wesley Liddick
9dcd3cd491 Merge pull request #754 from xvhuan/perf/admin-core-large-dataset
perf(admin): optimize admin loading for large datasets (dashboard/users/accounts/Ops)
2026-03-04 15:15:13 +08:00
Wesley Liddick
49767cccd2 Merge pull request #755 from xvhuan/perf/admin-usage-fast-pagination-main
perf(admin-usage): optimize pagination on the large usage table; avoid a full COUNT(*) by default
2026-03-04 14:15:57 +08:00
PMExtra
29fb447daa fix(frontend): remove unused variables 2026-03-04 14:12:08 +08:00
xvhuan
f6fe5b552d fix(admin): resolve CI lint and user subscriptions regression 2026-03-04 14:07:17 +08:00
PMExtra
bd0801a887 feat(registration): add email domain whitelist policy 2026-03-04 13:54:18 +08:00
xvhuan
05b1c66aa8 perf(admin-usage): avoid expensive count on large usage_logs pagination 2026-03-04 13:51:27 +08:00
xvhuan
80ae592c23 perf(admin): optimize large-dataset loading for dashboard/users/accounts/ops 2026-03-04 13:45:49 +08:00
shaw
ba6de4c4d4 feat: support form-based filtering on the /keys page 2026-03-04 11:29:31 +08:00
shaw
46ea9170cb fix: fix custom menu pages not taking effect in the admin view 2026-03-04 10:44:28 +08:00
shaw
7d318aeefa fix: restore check_pnpm_audit_exceptions.py 2026-03-04 10:20:19 +08:00
shaw
0aa3cf677a chore: clean up some unused files 2026-03-04 10:15:42 +08:00
shaw
72961c5858 fix: a 429 from the Anthropic platform with no rate-limit reset time no longer wrongly marks the account as rate limited 2026-03-04 09:36:24 +08:00
Wesley Liddick
a05711a37a Merge pull request #742 from zqq-nuli/fix/ops-error-detail-upstream-payload
fix(frontend): show real upstream payload in ops error detail modal
2026-03-04 09:04:11 +08:00
zqq61
efc9e1d673 fix(frontend): prefer upstream payload for generic ops error body 2026-03-03 23:45:34 +08:00
167 changed files with 9899 additions and 4945 deletions

AGENTS.md

@@ -1,105 +0,0 @@
# Repository Guidelines
## Project Structure & Module Organization
- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output.
- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`.
- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`).
- `openspec/`: Spec-driven change docs (`changes/<id>/{proposal,design,tasks}.md`).
- `tools/`: Utility scripts (security/perf checks).
## Build, Test, and Development Commands
```bash
make build # Build backend + frontend
make test # Backend tests + frontend lint/typecheck
cd backend && make build # Build backend binary
cd backend && make test-unit # Go unit tests
cd backend && make test-integration # Go integration tests
cd backend && make test # go test ./... + golangci-lint
cd frontend && pnpm install --frozen-lockfile
cd frontend && pnpm dev # Vite dev server
cd frontend && pnpm build # Type-check + production build
cd frontend && pnpm test:run # Vitest run
cd frontend && pnpm test:coverage # Vitest + coverage report
python3 tools/secret_scan.py # Secret scan
```
## Coding Style & Naming Conventions
- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`).
- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard).
- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`.
- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`.
## Go & Frontend Development Standards
- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears.
- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency.
- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable.
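The partial-extraction rule above can be sketched with the standard library alone; `gjson.Get(raw, "usage.input_tokens")` achieves the same with no struct decode at all, but gjson is a third-party dependency not assumed here, and the payload shape below is a hypothetical illustration, not code from this repo:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// usageProbe decodes only the two fields the hot path needs; every other
// key in the payload is skipped, avoiding allocations for the full schema.
type usageProbe struct {
	Model string `json:"model"`
	Usage struct {
		InputTokens int `json:"input_tokens"`
	} `json:"usage"`
}

// extract pulls the model name and input-token count out of a raw response.
func extract(raw []byte) (string, int) {
	var p usageProbe
	if err := json.Unmarshal(raw, &p); err != nil {
		return "", 0
	}
	return p.Model, p.Usage.InputTokens
}

func main() {
	raw := []byte(`{"model":"claude-3","usage":{"input_tokens":42,"output_tokens":7},"content":[{"type":"text","text":"hi"}]}`)
	model, tokens := extract(raw)
	fmt.Println(model, tokens) // claude-3 42
}
```

When full schema validation or typed writes are needed, the exception rule applies: decode the whole payload with `encoding/json` and justify it in the PR.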
### Go Performance Rules
- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code.
- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable.
- Every external I/O path must use `context.Context` with explicit timeout/cancel.
- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources.
- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`).
- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`.
- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths.
- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after).
- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out.
- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention.
- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible.
- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary.
- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions.
- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain.
- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible.
- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled.
### Data Access & Cache Rules
- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR.
- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints.
- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths.
- Keep transactions short; never perform external RPC/network calls inside DB transactions.
- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`).
- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks.
- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd.
- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths.
### Frontend Performance Rules
- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components.
- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs.
- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it.
- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed.
- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows).
- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking.
- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review.
- Avoid expensive inline computations in templates; move to cached `computed` selectors.
- Keep state normalized; avoid duplicated derived state across multiple stores/components.
- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import.
- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`.
- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks.
- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes.
### Performance Budget & PR Evidence
- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline.
- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval.
- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table).
- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output).
### Quality Gate
- Any changed code must include new or updated unit tests.
- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules).
- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description.
## Testing Guidelines
- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant.
- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`.
- Enforce unit-test and coverage rules defined in `Quality Gate`.
- Before opening a PR, run `make test` plus targeted tests for touched areas.
## Commit & Pull Request Guidelines
- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`.
- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes.
- For behavior/API changes, add or update `openspec/changes/...` artifacts.
- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR.
## Security & Configuration Tips
- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials.
- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev.

```diff
@@ -137,8 +137,6 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
 使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
-
-如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。
 
 #### 前置条件
 
 - Docker 20.10+
```

```diff
@@ -86,6 +86,7 @@ func provideCleanup(
 	geminiOAuth *service.GeminiOAuthService,
 	antigravityOAuth *service.AntigravityOAuthService,
 	openAIGateway *service.OpenAIGatewayService,
+	scheduledTestRunner *service.ScheduledTestRunnerService,
 ) func() {
 	return func() {
 		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -216,6 +217,12 @@ func provideCleanup(
 			}
 			return nil
 		}},
+		{"ScheduledTestRunnerService", func() error {
+			if scheduledTestRunner != nil {
+				scheduledTestRunner.Stop()
+			}
+			return nil
+		}},
 	}
 	infraSteps := []cleanupStep{
```

```diff
@@ -195,7 +195,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
 	errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
 	adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
-	adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
+	scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
+	scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
+	scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
+	scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
+	adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
 	usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
 	userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
 	userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -225,7 +229,8 @@
 	tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
 	accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
 	subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
-	v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
+	scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
+	v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
 	application := &Application{
 		Server:  httpServer,
 		Cleanup: v,
@@ -273,6 +278,7 @@ func provideCleanup(
 	geminiOAuth *service.GeminiOAuthService,
 	antigravityOAuth *service.AntigravityOAuthService,
 	openAIGateway *service.OpenAIGatewayService,
+	scheduledTestRunner *service.ScheduledTestRunnerService,
 ) func() {
 	return func() {
 		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -402,6 +408,12 @@ func provideCleanup(
 			}
 			return nil
 		}},
+		{"ScheduledTestRunnerService", func() error {
+			if scheduledTestRunner != nil {
+				scheduledTestRunner.Stop()
+			}
+			return nil
+		}},
 	}
 	infraSteps := []cleanupStep{
```

```diff
@@ -74,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
 		geminiOAuthSvc,
 		antigravityOAuthSvc,
 		nil, // openAIGateway
+		nil, // scheduledTestRunner
 	)
 	require.NotPanics(t, func() {
```

```diff
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
 type GatewayOpenAIWSConfig struct {
 	// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false,关闭时保持 legacy 行为)
 	ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
-	// IngressModeDefault: ingress 默认模式:off/shared/dedicated
+	// IngressModeDefault: ingress 默认模式:off/ctx_pool/passthrough
 	IngressModeDefault string `mapstructure:"ingress_mode_default"`
 	// Enabled: 全局总开关(默认 true)
 	Enabled bool `mapstructure:"enabled"`
@@ -1227,7 +1227,7 @@ func setDefaults() {
 	// Ops (vNext)
 	viper.SetDefault("ops.enabled", true)
-	viper.SetDefault("ops.use_preaggregated_tables", false)
+	viper.SetDefault("ops.use_preaggregated_tables", true)
 	viper.SetDefault("ops.cleanup.enabled", true)
 	viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
 	// Retention days: vNext defaults to 30 days across ops datasets.
@@ -1335,7 +1335,7 @@ func setDefaults() {
 	// OpenAI Responses WebSocket(默认开启,可通过 force_http 紧急回滚)
 	viper.SetDefault("gateway.openai_ws.enabled", true)
 	viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
-	viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
+	viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
 	viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
 	viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
 	viper.SetDefault("gateway.openai_ws.force_http", false)
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
 	}
 	if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
 		switch mode {
-		case "off", "shared", "dedicated":
+		case "off", "ctx_pool", "passthrough":
+		case "shared", "dedicated":
+			slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
 		default:
-			return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
+			return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
 		}
 	}
 	if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
```

```diff
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
 	if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
 		t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
 	}
-	if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
-		t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
+	if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
+		t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
 	}
 }
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
 			wantErr: "gateway.openai_ws.store_disabled_conn_mode",
 		},
 		{
-			name:    "ingress_mode_default 必须为 off|shared|dedicated",
+			name:    "ingress_mode_default 必须为 off|ctx_pool|passthrough",
 			mutate:  func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
 			wantErr: "gateway.openai_ws.ingress_mode_default",
 		},
```

View File

@@ -217,6 +217,7 @@ func (h *AccountHandler) List(c *gin.Context) {
if len(search) > 100 { if len(search) > 100 {
search = search[:100] search = search[:100]
} }
lite := parseBoolQueryWithDefault(c.Query("lite"), false)
var groupID int64 var groupID int64
if groupIDStr := c.Query("group"); groupIDStr != "" { if groupIDStr := c.Query("group"); groupIDStr != "" {
@@ -235,10 +236,16 @@ func (h *AccountHandler) List(c *gin.Context) {
accountIDs[i] = acc.ID accountIDs[i] = acc.ID
} }
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs) concurrencyCounts := make(map[int64]int)
if err != nil { var windowCosts map[int64]float64
// Log error but don't fail the request, just use 0 for all var activeSessions map[int64]int
concurrencyCounts = make(map[int64]int) var rpmCounts map[int64]int
// 始终获取并发数Redis ZCARD极低开销
if h.concurrencyService != nil {
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
concurrencyCounts = cc
}
} }
// 识别需要查询窗口费用、会话数和 RPM 的账号Anthropic OAuth/SetupToken 且启用了相应功能) // 识别需要查询窗口费用、会话数和 RPM 的账号Anthropic OAuth/SetupToken 且启用了相应功能)
@@ -262,12 +269,7 @@ func (h *AccountHandler) List(c *gin.Context) {
} }
} }
// 并行获取窗口费用、活跃会话数和 RPM 计数 // 始终获取 RPM 计数Redis GET极低开销
var windowCosts map[int64]float64
var activeSessions map[int64]int
var rpmCounts map[int64]int
// 获取 RPM 计数(批量查询)
if len(rpmAccountIDs) > 0 && h.rpmCache != nil { if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs) rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
if rpmCounts == nil { if rpmCounts == nil {
@@ -275,7 +277,7 @@ func (h *AccountHandler) List(c *gin.Context) {
} }
} }
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置 // 始终获取活跃会话数(Redis ZCARD低开销
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts) activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil { if activeSessions == nil {
@@ -283,8 +285,8 @@ func (h *AccountHandler) List(c *gin.Context) {
} }
} }
// 获取窗口费用(并行查询 // 仅非 lite 模式获取窗口费用PostgreSQL 聚合查询,高开销
if len(windowCostAccountIDs) > 0 { if !lite && len(windowCostAccountIDs) > 0 {
windowCosts = make(map[int64]float64) windowCosts = make(map[int64]float64)
var mu sync.Mutex var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context()) g, gctx := errgroup.WithContext(c.Request.Context())
@@ -344,7 +346,7 @@ func (h *AccountHandler) List(c *gin.Context) {
result[i] = item result[i] = item
} }
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search) etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite)
if etag != "" { if etag != "" {
c.Header("ETag", etag) c.Header("ETag", etag)
c.Header("Vary", "If-None-Match") c.Header("Vary", "If-None-Match")
@@ -362,6 +364,7 @@ func buildAccountsListETag(
total int64, total int64,
page, pageSize int, page, pageSize int,
platform, accountType, status, search string, platform, accountType, status, search string,
lite bool,
) string { ) string {
payload := struct { payload := struct {
Total int64 `json:"total"` Total int64 `json:"total"`
@@ -371,6 +374,7 @@ func buildAccountsListETag(
AccountType string `json:"type"`
Status string `json:"status"`
Search string `json:"search"`
+Lite bool `json:"lite"`
Items []AccountWithConcurrency `json:"items"`
}{
Total: total,
@@ -380,6 +384,7 @@ func buildAccountsListETag(
AccountType: accountType,
Status: status,
Search: search,
+Lite: lite,
Items: items,
}
raw, err := json.Marshal(payload)
@@ -1398,18 +1403,41 @@ func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
return
}
-if len(req.AccountIDs) == 0 {
+accountIDs := normalizeInt64IDList(req.AccountIDs)
+if len(accountIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
-stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
+cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs)
+if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok {
+if cached.ETag != "" {
+c.Header("ETag", cached.ETag)
+c.Header("Vary", "If-None-Match")
+if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
+c.Status(http.StatusNotModified)
+return
+}
+}
+c.Header("X-Snapshot-Cache", "hit")
+response.Success(c, cached.Payload)
+return
+}
+stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
-response.Success(c, gin.H{"stats": stats})
+payload := gin.H{"stats": stats}
+cached := accountTodayStatsBatchCache.Set(cacheKey, payload)
+if cached.ETag != "" {
+c.Header("ETag", cached.ETag)
+c.Header("Vary", "If-None-Match")
+}
+c.Header("X-Snapshot-Cache", "miss")
+response.Success(c, payload)
}
// SetSchedulableRequest represents the request body for setting schedulable status


@@ -0,0 +1,25 @@
package admin
import (
"strconv"
"strings"
"time"
)
var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second)
func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string {
if len(accountIDs) == 0 {
return "accounts_today_stats_empty"
}
var b strings.Builder
b.Grow(len(accountIDs) * 6)
_, _ = b.WriteString("accounts_today_stats:")
for i, id := range accountIDs {
if i > 0 {
_ = b.WriteByte(',')
}
_, _ = b.WriteString(strconv.FormatInt(id, 10))
}
return b.String()
}
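The cache key above only stays stable if callers normalize IDs first (as `GetBatchTodayStats` does via `normalizeInt64IDList`), so that reordered or duplicated ID lists map to the same entry. A minimal, self-contained sketch of that interplay — `normalizeIDs` and `buildKey` are illustrative renames of the helpers above, not names from this PR:

```go
package main

import (
	"fmt"
	"sort"
	"strconv"
	"strings"
)

// normalizeIDs mirrors normalizeInt64IDList: drop non-positive IDs,
// dedupe, and sort ascending.
func normalizeIDs(ids []int64) []int64 {
	seen := make(map[int64]struct{}, len(ids))
	out := make([]int64, 0, len(ids))
	for _, id := range ids {
		if id <= 0 {
			continue
		}
		if _, ok := seen[id]; ok {
			continue
		}
		seen[id] = struct{}{}
		out = append(out, id)
	}
	sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
	return out
}

// buildKey mirrors buildAccountTodayStatsBatchCacheKey: a fixed prefix
// plus the comma-joined decimal IDs.
func buildKey(ids []int64) string {
	if len(ids) == 0 {
		return "accounts_today_stats_empty"
	}
	var b strings.Builder
	b.WriteString("accounts_today_stats:")
	for i, id := range ids {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(strconv.FormatInt(id, 10))
	}
	return b.String()
}

func main() {
	// Two requests with the same IDs in different order (and with junk)
	// normalize to the same key, so they share one cache entry.
	fmt.Println(buildKey(normalizeIDs([]int64{3, 1, 0, 3, 2}))) // accounts_today_stats:1,2,3
	fmt.Println(buildKey(normalizeIDs([]int64{2, 3, 1})))       // accounts_today_stats:1,2,3
}
```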


@@ -1,6 +1,7 @@
package admin
import (
+"encoding/json"
"errors"
"strconv"
"strings"
@@ -460,6 +461,9 @@ type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
+var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
+var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
@@ -469,18 +473,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
-if len(req.UserIDs) == 0 {
+userIDs := normalizeInt64IDList(req.UserIDs)
+if len(userIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
-stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
+keyRaw, _ := json.Marshal(struct {
+UserIDs []int64 `json:"user_ids"`
+}{
+UserIDs: userIDs,
+})
+cacheKey := string(keyRaw)
+if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok {
+c.Header("X-Snapshot-Cache", "hit")
+response.Success(c, cached.Payload)
+return
+}
+stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
}
-response.Success(c, gin.H{"stats": stats})
+payload := gin.H{"stats": stats}
+dashboardBatchUsersUsageCache.Set(cacheKey, payload)
+c.Header("X-Snapshot-Cache", "miss")
+response.Success(c, payload)
}
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
@@ -497,16 +517,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
-if len(req.APIKeyIDs) == 0 {
+apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs)
+if len(apiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
-stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
+keyRaw, _ := json.Marshal(struct {
+APIKeyIDs []int64 `json:"api_key_ids"`
+}{
+APIKeyIDs: apiKeyIDs,
+})
+cacheKey := string(keyRaw)
+if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok {
+c.Header("X-Snapshot-Cache", "hit")
+response.Success(c, cached.Payload)
+return
+}
+stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return
}
-response.Success(c, gin.H{"stats": stats})
+payload := gin.H{"stats": stats}
+dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload)
+c.Header("X-Snapshot-Cache", "miss")
+response.Success(c, payload)
}


@@ -0,0 +1,292 @@
package admin
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type dashboardSnapshotV2Stats struct {
usagestats.DashboardStats
Uptime int64 `json:"uptime"`
}
type dashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
Granularity string `json:"granularity"`
Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"`
Trend []usagestats.TrendDataPoint `json:"trend,omitempty"`
Models []usagestats.ModelStat `json:"models,omitempty"`
Groups []usagestats.GroupStat `json:"groups,omitempty"`
UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"`
}
type dashboardSnapshotV2Filters struct {
UserID int64
APIKeyID int64
AccountID int64
GroupID int64
Model string
RequestType *int16
Stream *bool
BillingType *int8
}
type dashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Model string `json:"model"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
IncludeStats bool `json:"include_stats"`
IncludeTrend bool `json:"include_trend"`
IncludeModels bool `json:"include_models"`
IncludeGroups bool `json:"include_groups"`
IncludeUsersTrend bool `json:"include_users_trend"`
UsersTrendLimit int `json:"users_trend_limit"`
}
func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day"))
if granularity != "hour" {
granularity = "day"
}
includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true)
includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true)
includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true)
includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false)
includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false)
usersTrendLimit := 12
if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" {
if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 {
usersTrendLimit = parsed
}
}
filters, err := parseDashboardSnapshotV2Filters(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
UserID: filters.UserID,
APIKeyID: filters.APIKeyID,
AccountID: filters.AccountID,
GroupID: filters.GroupID,
Model: filters.Model,
RequestType: filters.RequestType,
Stream: filters.Stream,
BillingType: filters.BillingType,
IncludeStats: includeStats,
IncludeTrend: includeTrend,
IncludeModels: includeModels,
IncludeGroups: includeGroups,
IncludeUsersTrend: includeUsersTrend,
UsersTrendLimit: usersTrendLimit,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
resp := &dashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
StartDate: startTime.Format("2006-01-02"),
EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"),
Granularity: granularity,
}
if includeStats {
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
}
resp.Stats = &dashboardSnapshotV2Stats{
DashboardStats: *stats,
Uptime: int64(time.Since(h.startTime).Seconds()),
}
}
if includeTrend {
trend, err := h.dashboardService.GetUsageTrendWithFilters(
c.Request.Context(),
startTime,
endTime,
granularity,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.Model,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
}
resp.Trend = trend
}
if includeModels {
models, err := h.dashboardService.GetModelStatsWithFilters(
c.Request.Context(),
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
}
resp.Models = models
}
if includeGroups {
groups, err := h.dashboardService.GetGroupStatsWithFilters(
c.Request.Context(),
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
}
resp.Groups = groups
}
if includeUsersTrend {
usersTrend, err := h.dashboardService.GetUserUsageTrend(
c.Request.Context(),
startTime,
endTime,
granularity,
usersTrendLimit,
)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
}
resp.UsersTrend = usersTrend
}
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
}
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
filters := &dashboardSnapshotV2Filters{
Model: strings.TrimSpace(c.Query("model")),
}
if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.UserID = id
}
if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.APIKeyID = id
}
if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" {
id, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.AccountID = id
}
if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" {
id, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.GroupID = id
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
return nil, err
}
value := int16(parsed)
filters.RequestType = &value
} else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" {
streamVal, err := strconv.ParseBool(streamStr)
if err != nil {
return nil, err
}
filters.Stream = &streamVal
}
if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" {
v, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
return nil, err
}
bt := int8(v)
filters.BillingType = &bt
}
return filters, nil
}


@@ -0,0 +1,25 @@
package admin
import "sort"
func normalizeInt64IDList(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
out := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
return out
}


@@ -0,0 +1,57 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNormalizeInt64IDList(t *testing.T) {
tests := []struct {
name string
in []int64
want []int64
}{
{"nil input", nil, nil},
{"empty input", []int64{}, nil},
{"single element", []int64{5}, []int64{5}},
{"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}},
{"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}},
{"zero filtered", []int64{0, 1, 2}, []int64{1, 2}},
{"negative filtered", []int64{-5, -1, 3}, []int64{3}},
{"all invalid", []int64{0, -1, -2}, []int64{}},
{"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := normalizeInt64IDList(tc.in)
if tc.want == nil {
require.Nil(t, got)
} else {
require.Equal(t, tc.want, got)
}
})
}
}
func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) {
tests := []struct {
name string
ids []int64
want string
}{
{"empty", nil, "accounts_today_stats_empty"},
{"single", []int64{42}, "accounts_today_stats:42"},
{"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildAccountTodayStatsBatchCacheKey(tc.ids)
require.Equal(t, tc.want, got)
})
}
}


@@ -0,0 +1,145 @@
package admin
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type opsDashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
Overview *service.OpsDashboardOverview `json:"overview"`
ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"`
ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"`
}
type opsDashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id"`
QueryMode service.OpsQueryMode `json:"mode"`
BucketSecond int `json:"bucket_second"`
}
// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request.
// GET /api/v1/admin/ops/dashboard/snapshot-v2
func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Platform: filter.Platform,
GroupID: filter.GroupID,
QueryMode: filter.QueryMode,
BucketSecond: bucketSeconds,
})
cacheKey := string(keyRaw)
if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
var (
overview *service.OpsDashboardOverview
trend *service.OpsThroughputTrendResponse
errTrend *service.OpsErrorTrendResponse
)
g, gctx := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
f := *filter
result, err := h.opsService.GetDashboardOverview(gctx, &f)
if err != nil {
return err
}
overview = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
trend = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
errTrend = result
return nil
})
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
resp := &opsDashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
Overview: overview,
ThroughputTrend: trend,
ErrorTrend: errTrend,
}
cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
}


@@ -0,0 +1,155 @@
package admin
import (
"net/http"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ScheduledTestHandler handles admin scheduled-test-plan management.
type ScheduledTestHandler struct {
scheduledTestSvc *service.ScheduledTestService
}
// NewScheduledTestHandler creates a new ScheduledTestHandler.
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
}
type createScheduledTestPlanRequest struct {
AccountID int64 `json:"account_id" binding:"required"`
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression" binding:"required"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
}
type updateScheduledTestPlanRequest struct {
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
}
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid account id")
return
}
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, plans)
}
// Create POST /admin/scheduled-test-plans
func (h *ScheduledTestHandler) Create(c *gin.Context) {
var req createScheduledTestPlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
plan := &service.ScheduledTestPlan{
AccountID: req.AccountID,
ModelID: req.ModelID,
CronExpression: req.CronExpression,
Enabled: true,
MaxResults: req.MaxResults,
}
if req.Enabled != nil {
plan.Enabled = *req.Enabled
}
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
if err != nil {
response.BadRequest(c, err.Error())
return
}
c.JSON(http.StatusOK, created)
}
// Update PUT /admin/scheduled-test-plans/:id
func (h *ScheduledTestHandler) Update(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
if err != nil {
response.NotFound(c, "plan not found")
return
}
var req updateScheduledTestPlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if req.ModelID != "" {
existing.ModelID = req.ModelID
}
if req.CronExpression != "" {
existing.CronExpression = req.CronExpression
}
if req.Enabled != nil {
existing.Enabled = *req.Enabled
}
if req.MaxResults > 0 {
existing.MaxResults = req.MaxResults
}
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
if err != nil {
response.BadRequest(c, err.Error())
return
}
c.JSON(http.StatusOK, updated)
}
// Delete DELETE /admin/scheduled-test-plans/:id
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
}
// ListResults GET /admin/scheduled-test-plans/:id/results
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
limit := 50
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
limit = l
}
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
if err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, results)
}


@@ -77,6 +77,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
@@ -130,12 +131,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest is the request body for updating settings
type UpdateSettingsRequest struct {
// registration settings
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
// email service settings
SMTPHost string `json:"smtp_host"`
@@ -426,50 +428,51 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
+RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -520,6 +523,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
+RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
@@ -598,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
+if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
+changed = append(changed, "registration_email_suffix_whitelist")
+}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
@@ -747,6 +754,18 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized
}
+func equalStringSlice(a, b []string) bool {
+if len(a) != len(b) {
+return false
+}
+for i := range a {
+if a[i] != b[i] {
+return false
+}
+}
+return true
+}
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
if len(a) != len(b) {
return false
@@ -800,7 +819,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
err := h.emailService.TestSMTPConnectionWithConfig(config)
if err != nil {
-response.ErrorFrom(c, err)
+response.BadRequest(c, "SMTP connection test failed: "+err.Error())
return
}
@@ -886,7 +905,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
`
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
-response.ErrorFrom(c, err)
+response.BadRequest(c, "Failed to send test email: "+err.Error())
return
}
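The new `equalStringSlice` comparison in `diffSettings` is element-wise and order-sensitive, so even a pure reorder of the whitelist suffixes is recorded as a change. A small sketch of that behavior (`equalStrings` and the example suffixes are illustrative stand-ins, not names from this PR):

```go
package main

import "fmt"

// equalStrings mirrors equalStringSlice: same length and same elements
// at the same positions.
func equalStrings(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}

func main() {
	before := []string{"example.com", "corp.example.com"}
	after := []string{"corp.example.com", "example.com"}
	// Same suffixes in a different order count as a change, so the settings
	// diff would log "registration_email_suffix_whitelist" for a reorder.
	fmt.Println(equalStrings(before, after)) // false
}
```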


@@ -0,0 +1,95 @@
package admin
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"strings"
"sync"
"time"
)
type snapshotCacheEntry struct {
ETag string
Payload any
ExpiresAt time.Time
}
type snapshotCache struct {
mu sync.RWMutex
ttl time.Duration
items map[string]snapshotCacheEntry
}
func newSnapshotCache(ttl time.Duration) *snapshotCache {
if ttl <= 0 {
ttl = 30 * time.Second
}
return &snapshotCache{
ttl: ttl,
items: make(map[string]snapshotCacheEntry),
}
}
func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) {
if c == nil || key == "" {
return snapshotCacheEntry{}, false
}
now := time.Now()
c.mu.RLock()
entry, ok := c.items[key]
c.mu.RUnlock()
if !ok {
return snapshotCacheEntry{}, false
}
if now.After(entry.ExpiresAt) {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
return snapshotCacheEntry{}, false
}
return entry, true
}
func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
if c == nil {
return snapshotCacheEntry{}
}
entry := snapshotCacheEntry{
ETag: buildETagFromAny(payload),
Payload: payload,
ExpiresAt: time.Now().Add(c.ttl),
}
if key == "" {
return entry
}
c.mu.Lock()
c.items[key] = entry
c.mu.Unlock()
return entry
}
func buildETagFromAny(payload any) string {
raw, err := json.Marshal(payload)
if err != nil {
return ""
}
sum := sha256.Sum256(raw)
return "\"" + hex.EncodeToString(sum[:]) + "\""
}
func parseBoolQueryWithDefault(raw string, def bool) bool {
value := strings.TrimSpace(strings.ToLower(raw))
if value == "" {
return def
}
switch value {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
default:
return def
}
}

View File

@@ -0,0 +1,128 @@
//go:build unit
package admin
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSnapshotCache_SetAndGet(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("key1", map[string]string{"hello": "world"})
require.NotEmpty(t, entry.ETag)
require.NotNil(t, entry.Payload)
got, ok := c.Get("key1")
require.True(t, ok)
require.Equal(t, entry.ETag, got.ETag)
}
func TestSnapshotCache_Expiration(t *testing.T) {
c := newSnapshotCache(1 * time.Millisecond)
c.Set("key1", "value")
time.Sleep(5 * time.Millisecond)
_, ok := c.Get("key1")
require.False(t, ok, "expired entry should not be returned")
}
func TestSnapshotCache_GetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_GetMiss(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("nonexistent")
require.False(t, ok)
}
func TestSnapshotCache_NilReceiver(t *testing.T) {
var c *snapshotCache
_, ok := c.Get("key")
require.False(t, ok)
entry := c.Set("key", "value")
require.Empty(t, entry.ETag)
}
func TestSnapshotCache_SetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
// Set with empty key should return entry but not store it
entry := c.Set("", "value")
require.NotEmpty(t, entry.ETag)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_DefaultTTL(t *testing.T) {
c := newSnapshotCache(0)
require.Equal(t, 30*time.Second, c.ttl)
c2 := newSnapshotCache(-1 * time.Second)
require.Equal(t, 30*time.Second, c2.ttl)
}
func TestSnapshotCache_ETagDeterministic(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
payload := map[string]int{"a": 1, "b": 2}
entry1 := c.Set("k1", payload)
entry2 := c.Set("k2", payload)
require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag")
}
func TestSnapshotCache_ETagFormat(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("k", "test")
// ETag should be quoted hex string: "abcdef..."
require.True(t, len(entry.ETag) > 2)
require.Equal(t, byte('"'), entry.ETag[0])
require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1])
}
func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
// channels are not JSON-serializable
etag := buildETagFromAny(make(chan int))
require.Empty(t, etag)
}
func TestParseBoolQueryWithDefault(t *testing.T) {
tests := []struct {
name string
raw string
def bool
want bool
}{
{"empty returns default true", "", true, true},
{"empty returns default false", "", false, false},
{"1", "1", false, true},
{"true", "true", false, true},
{"TRUE", "TRUE", false, true},
{"yes", "yes", false, true},
{"on", "on", false, true},
{"0", "0", true, false},
{"false", "false", true, false},
{"FALSE", "FALSE", true, false},
{"no", "no", true, false},
{"off", "off", true, false},
{"whitespace trimmed", " true ", false, true},
{"unknown returns default true", "maybe", true, true},
{"unknown returns default false", "maybe", false, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := parseBoolQueryWithDefault(tc.raw, tc.def)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -61,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct {
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
exactTotal := false
if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" {
parsed, err := strconv.ParseBool(exactTotalRaw)
if err != nil {
response.BadRequest(c, "Invalid exact_total value, use true or false")
return
}
exactTotal = parsed
}
// Parse filters
var userID, apiKeyID, accountID, groupID int64
@@ -167,6 +176,7 @@ func (h *UsageHandler) List(c *gin.Context) {
BillingType: billingType,
StartTime: startTime,
EndTime: endTime,
+ExactTotal: exactTotal,
}
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)

View File

@@ -80,6 +80,29 @@ func TestAdminUsageListInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageListExactTotalTrue(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.True(t, repo.listFilters.ExactTotal)
}
func TestAdminUsageListInvalidExactTotal(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)

View File

@@ -1,7 +1,9 @@
package admin package admin
import ( import (
"encoding/json"
"strconv" "strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct {
Attributes map[int64]map[int64]string `json:"attributes"` Attributes map[int64]map[int64]string `json:"attributes"`
} }
var userAttributesBatchCache = newSnapshotCache(30 * time.Second)
// AttributeDefinitionResponse represents attribute definition response // AttributeDefinitionResponse represents attribute definition response
type AttributeDefinitionResponse struct { type AttributeDefinitionResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
@@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
return return
} }
if len(req.UserIDs) == 0 { userIDs := normalizeInt64IDList(req.UserIDs)
if len(userIDs) == 0 {
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}}) response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
return return
} }
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs) keyRaw, _ := json.Marshal(struct {
UserIDs []int64 `json:"user_ids"`
}{
UserIDs: userIDs,
})
cacheKey := string(keyRaw)
if cached, ok := userAttributesBatchCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, BatchUserAttributesResponse{Attributes: attrs}) payload := BatchUserAttributesResponse{Attributes: attrs}
userAttributesBatchCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
} }

View File

@@ -91,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) {
Search: search,
Attributes: parseAttributeFilters(c),
}
+if raw, ok := c.GetQuery("include_subscriptions"); ok {
+includeSubscriptions := parseBoolQueryWithDefault(raw, true)
+filters.IncludeSubscriptions = &includeSubscriptions
+}
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
if err != nil {

View File

@@ -4,6 +4,7 @@ package handler
import (
"context"
"strconv"
+"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -73,7 +74,23 @@ func (h *APIKeyHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
-keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
+// Parse filter parameters
+var filters service.APIKeyListFilters
+if search := strings.TrimSpace(c.Query("search")); search != "" {
+if len(search) > 100 {
+search = search[:100]
+}
+filters.Search = search
+}
+filters.Status = c.Query("status")
+if groupIDStr := c.Query("group_id"); groupIDStr != "" {
+gid, err := strconv.ParseInt(groupIDStr, 10, 64)
+if err == nil {
+filters.GroupID = &gid
+}
+}
+keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params, filters)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -17,13 +17,14 @@ type CustomMenuItem struct {
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // whether the TOTP encryption key is configured
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -88,28 +89,29 @@ type DefaultSubscriptionSetting struct {
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"`
Version string `json:"version"`
}
// SoraS3Settings is the Sora S3 storage settings DTO (the response omits sensitive fields)

View File

@@ -27,6 +27,7 @@ type AdminHandlers struct {
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
APIKey *admin.AdminAPIKeyHandler
+ScheduledTest *admin.ScheduledTestHandler
}
// Handlers contains all HTTP handlers

View File

@@ -0,0 +1,192 @@
package handler
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
var handlerStructuredLogCaptureMu sync.Mutex
type handlerInMemoryLogSink struct {
mu sync.Mutex
events []*logger.LogEvent
}
func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
if event == nil {
return
}
cloned := *event
if event.Fields != nil {
cloned.Fields = make(map[string]any, len(event.Fields))
for k, v := range event.Fields {
cloned.Fields[k] = v
}
}
s.mu.Lock()
s.events = append(s.events, &cloned)
s.mu.Unlock()
}
func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
s.mu.Lock()
defer s.mu.Unlock()
wantLevel := strings.ToLower(strings.TrimSpace(level))
for _, ev := range s.events {
if ev == nil {
continue
}
if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel {
return true
}
}
return false
}
func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool {
s.mu.Lock()
defer s.mu.Unlock()
for _, ev := range s.events {
if ev == nil || ev.Fields == nil {
continue
}
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
return true
}
}
return false
}
func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) {
t.Helper()
handlerStructuredLogCaptureMu.Lock()
err := logger.Init(logger.InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: logger.SamplingOptions{Enabled: false},
})
require.NoError(t, err)
sink := &handlerInMemoryLogSink{}
logger.SetSink(sink)
return sink, func() {
logger.SetSink(nil)
handlerStructuredLogCaptureMu.Unlock()
}
}
func TestIsOpenAIRemoteCompactPath(t *testing.T) {
require.False(t, isOpenAIRemoteCompactPath(nil))
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
require.True(t, isOpenAIRemoteCompactPath(c))
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil)
require.True(t, isOpenAIRemoteCompactPath(c))
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
require.False(t, isOpenAIRemoteCompactPath(c))
}
func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Set(opsModelKey, "gpt-5.3-codex")
c.Set(opsAccountIDKey, int64(123))
c.Header("x-request-id", "rid-compact-ok")
c.Status(http.StatusOK)
h := &OpenAIGatewayHandler{}
h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond))
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded"))
require.True(t, logSink.ContainsFieldValue("status_code", "200"))
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex"))
require.True(t, logSink.ContainsFieldValue("account_id", "123"))
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok"))
}
func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Status(http.StatusBadGateway)
h := &OpenAIGatewayHandler{}
h.logOpenAIRemoteCompactOutcome(c, time.Now())
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed"))
require.True(t, logSink.ContainsFieldValue("status_code", "502"))
require.True(t, logSink.ContainsFieldValue("path", "/responses/compact"))
}
func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
c.Status(http.StatusOK)
h := &OpenAIGatewayHandler{}
h.logOpenAIRemoteCompactOutcome(c, time.Now())
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
}
func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
h := &OpenAIGatewayHandler{}
h.Responses(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
require.True(t, logSink.ContainsFieldValue("status_code", "401"))
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
}

View File

@@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
+cfg *config.Config
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -61,6 +62,7 @@ func NewOpenAIGatewayHandler(
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
+cfg: cfg,
}
}
@@ -70,6 +72,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Local safety net: make sure no panic inside this handler can escape to the process level.
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
+compactStartedAt := time.Now()
+defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt)
setOpenAIClientTransportHTTP(c)
requestStart := time.Now()
@@ -340,6 +344,86 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
}
func isOpenAIRemoteCompactPath(c *gin.Context) bool {
if c == nil || c.Request == nil || c.Request.URL == nil {
return false
}
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
return strings.HasSuffix(normalizedPath, "/responses/compact")
}
func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) {
if !isOpenAIRemoteCompactPath(c) {
return
}
var (
ctx = context.Background()
path string
status int
)
if c != nil {
if c.Request != nil {
ctx = c.Request.Context()
if c.Request.URL != nil {
path = strings.TrimSpace(c.Request.URL.Path)
}
}
if c.Writer != nil {
status = c.Writer.Status()
}
}
outcome := "failed"
if status >= 200 && status < 300 {
outcome = "succeeded"
}
latencyMs := time.Since(startedAt).Milliseconds()
if latencyMs < 0 {
latencyMs = 0
}
fields := []zap.Field{
zap.String("component", "handler.openai_gateway.responses"),
zap.Bool("remote_compact", true),
zap.String("compact_outcome", outcome),
zap.Int("status_code", status),
zap.Int64("latency_ms", latencyMs),
zap.String("path", path),
zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI),
}
if c != nil {
if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" {
fields = append(fields, zap.String("request_user_agent", userAgent))
}
if v, ok := c.Get(opsModelKey); ok {
if model, ok := v.(string); ok && strings.TrimSpace(model) != "" {
fields = append(fields, zap.String("request_model", strings.TrimSpace(model)))
}
}
if v, ok := c.Get(opsAccountIDKey); ok {
if accountID, ok := v.(int64); ok && accountID > 0 {
fields = append(fields, zap.Int64("account_id", accountID))
}
}
if c.Writer != nil {
if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" {
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
} else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" {
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
}
}
}
log := logger.FromContext(ctx).With(fields...)
if outcome == "succeeded" {
log.Info("codex.remote_compact.succeeded")
return
}
log.Warn("codex.remote_compact.failed")
}
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
return true

View File

@@ -32,27 +32,28 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled,
Version: h.version,
})
}

View File

@@ -996,7 +996,7 @@ func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*se
}
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
-func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
+func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {

View File

@@ -30,6 +30,7 @@ func ProvideAdminHandlers(
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
apiKeyHandler *admin.AdminAPIKeyHandler,
+scheduledTestHandler *admin.ScheduledTestHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -53,6 +54,7 @@ func ProvideAdminHandlers(
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
APIKey: apiKeyHandler,
+ScheduledTest: scheduledTestHandler,
}
}
@@ -141,6 +143,7 @@ var ProviderSet = wire.NewSet(
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
admin.NewAdminAPIKeyHandler,
+admin.NewScheduledTestHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -57,25 +57,28 @@ type DashboardStats struct {
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
-CacheTokens int64 `json:"cache_tokens"`
+CacheCreationTokens int64 `json:"cache_creation_tokens"`
+CacheReadTokens int64 `json:"cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // standard billing
ActualCost float64 `json:"actual_cost"` // amount actually deducted
}
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
+CacheCreationTokens int64 `json:"cache_creation_tokens"`
+CacheReadTokens int64 `json:"cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // standard billing
ActualCost float64 `json:"actual_cost"` // amount actually deducted
}
// GroupStat represents usage statistics for a single group
@@ -154,6 +157,8 @@ type UsageLogFilters struct {
BillingType *int8
StartTime *time.Time
EndTime *time.Time
+// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
+ExactTotal bool
}
// UsageStats represents usage statistics

View File

@@ -437,6 +437,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
switch status {
case "rate_limited":
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
+case "temp_unschedulable":
+q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
+col := s.C("temp_unschedulable_until")
+s.Where(entsql.And(
+entsql.Not(entsql.IsNull(col)),
+entsql.GT(col, entsql.Expr("NOW()")),
+))
+}))
default:
q = q.Where(dbaccount.StatusEQ(status))
}
@@ -640,7 +648,17 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
SetStatus(service.StatusActive).
SetErrorMessage("").
Save(ctx)
-return err
+if err != nil {
+return err
+}
+// Clear the temporarily-unschedulable state and reset the 401 escalation chain
+_, _ = r.sql.ExecContext(ctx, `
+UPDATE accounts
+SET temp_unschedulable_until = NULL,
+temp_unschedulable_reason = NULL
+WHERE id = $1 AND deleted_at IS NULL
+`, id)
+return nil
}
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {

View File

@@ -281,9 +281,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
// Apply filters
if filters.Search != "" {
q = q.Where(apikey.Or(
apikey.NameContainsFold(filters.Search),
apikey.KeyContainsFold(filters.Search),
))
}
if filters.Status != "" {
q = q.Where(apikey.StatusEQ(filters.Status))
}
if filters.GroupID != nil {
if *filters.GroupID == 0 {
q = q.Where(apikey.GroupIDIsNil())
} else {
q = q.Where(apikey.GroupIDEQ(*filters.GroupID))
}
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err

View File

@@ -158,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
@@ -170,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
@@ -314,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)

View File

@@ -0,0 +1,183 @@
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// --- Plan Repository ---
type scheduledTestPlanRepository struct {
db *sql.DB
}
func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository {
return &scheduledTestPlanRepository{db: db}
}
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans WHERE id = $1
`, id)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans WHERE account_id = $1
ORDER BY created_at DESC
`, accountID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
return scanPlans(rows)
}
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans
WHERE enabled = true AND next_run_at <= $1
ORDER BY next_run_at ASC
`, now)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
return scanPlans(rows)
}
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
UPDATE scheduled_test_plans
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
WHERE id = $1
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error {
_, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id)
return err
}
func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error {
_, err := r.db.ExecContext(ctx, `
UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1
`, id, lastRunAt, nextRunAt)
return err
}
// --- Result Repository ---
type scheduledTestResultRepository struct {
db *sql.DB
}
func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository {
return &scheduledTestResultRepository{db: db}
}
func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) {
row := r.db.QueryRowContext(ctx, `
INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
`, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt)
out := &service.ScheduledTestResult{}
if err := row.Scan(
&out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage,
&out.LatencyMs, &out.StartedAt, &out.FinishedAt, &out.CreatedAt,
); err != nil {
return nil, err
}
return out, nil
}
func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
FROM scheduled_test_results
WHERE plan_id = $1
ORDER BY created_at DESC
LIMIT $2
`, planID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var results []*service.ScheduledTestResult
for rows.Next() {
r := &service.ScheduledTestResult{}
if err := rows.Scan(
&r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage,
&r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt,
); err != nil {
return nil, err
}
results = append(results, r)
}
return results, rows.Err()
}
func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error {
_, err := r.db.ExecContext(ctx, `
DELETE FROM scheduled_test_results
WHERE id IN (
SELECT id FROM (
SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn
FROM scheduled_test_results
WHERE plan_id = $1
) ranked
WHERE rn > $2
)
`, planID, keepCount)
return err
}
// --- scan helpers ---
type scannable interface {
Scan(dest ...any) error
}
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
p := &service.ScheduledTestPlan{}
if err := row.Scan(
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, err
}
return p, nil
}
func scanPlans(rows *sql.Rows) ([]*service.ScheduledTestPlan, error) {
var plans []*service.ScheduledTestPlan
for rows.Next() {
p, err := scanPlan(rows)
if err != nil {
return nil, err
}
plans = append(plans, p)
}
return plans, rows.Err()
}
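The small `scannable` interface is what lets `scanPlan` serve both `*sql.Row` (from `QueryRowContext`) and `*sql.Rows` inside the `scanPlans` loop; any value with a matching `Scan` method satisfies it, which also makes scan helpers testable without a database. A minimal sketch (the `fakeRow` type and `scanIDPair` helper are hypothetical, not part of the repository):

```go
package main

import "fmt"

// scannable matches both *sql.Row and *sql.Rows, as in the repository.
type scannable interface {
	Scan(dest ...any) error
}

// fakeRow satisfies scannable from in-memory values, so scan helpers
// can be exercised without a database connection.
type fakeRow struct{ id, accountID int64 }

func (f fakeRow) Scan(dest ...any) error {
	if len(dest) != 2 {
		return fmt.Errorf("expected 2 destinations, got %d", len(dest))
	}
	*dest[0].(*int64) = f.id
	*dest[1].(*int64) = f.accountID
	return nil
}

// scanIDPair is a two-column analogue of scanPlan.
func scanIDPair(row scannable) (id, accountID int64, err error) {
	err = row.Scan(&id, &accountID)
	return id, accountID, err
}

func main() {
	id, acct, err := scanIDPair(fakeRow{id: 7, accountID: 42})
	fmt.Println(id, acct, err) // 7 42 <nil>
}
```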

View File

@@ -122,7 +122,7 @@ func (s *SettingRepoSuite) TestSet_EmptyValue() {
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
// Simulate saving site settings: some fields set, some left empty
settings := map[string]string{
"site_name": "Sub2api",
"site_subtitle": "Subscription to API",
"site_logo": "", // user has not uploaded a logo
"api_base_url": "", // user has not set an API base URL
@@ -136,7 +136,7 @@ func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")
s.Require().Equal("Sub2api", result["site_name"])
s.Require().Equal("Subscription to API", result["site_subtitle"])
s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")

View File

@@ -1363,7 +1363,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
@@ -1401,6 +1402,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
@@ -1473,7 +1476,16 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
}
whereClause := buildWhere(conditions)
var (
logs []service.UsageLog
page *pagination.PaginationResult
err error
)
if shouldUseFastUsageLogTotal(filters) {
logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params)
} else {
logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params)
}
if err != nil {
return nil, nil, err
}
@@ -1484,17 +1496,45 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
return logs, page, nil
}
func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool {
if filters.ExactTotal {
return false
}
// Under highly selective filters the result set is usually small, so keep the exact total.
return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0
}
// UsageStats represents usage statistics
type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats
func normalizePositiveInt64IDs(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
seen := make(map[int64]struct{}, len(ids))
out := make([]int64, 0, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
return out
}
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
result := make(map[int64]*BatchUserUsageStats)
normalizedUserIDs := normalizePositiveInt64IDs(userIDs)
if len(normalizedUserIDs) == 0 {
return result, nil
}
@@ -1506,58 +1546,36 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
endTime = time.Now()
}
for _, id := range normalizedUserIDs {
result[id] = &BatchUserUsageStats{UserID: id}
}
query := `
SELECT
user_id,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
FROM usage_logs
WHERE user_id = ANY($1)
AND created_at >= LEAST($2, $4)
GROUP BY user_id
`
today := timezone.Today()
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today)
if err != nil {
return nil, err
}
for rows.Next() {
var userID int64
var total float64
var todayTotal float64
if err := rows.Scan(&userID, &total, &todayTotal); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[userID]; ok {
stats.TotalActualCost = total
stats.TodayActualCost = todayTotal
}
}
if err := rows.Close(); err != nil {
@@ -1577,7 +1595,8 @@ type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs)
if len(normalizedAPIKeyIDs) == 0 {
return result, nil
}
@@ -1589,58 +1608,36 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
endTime = time.Now()
}
for _, id := range normalizedAPIKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
query := `
SELECT
api_key_id,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
FROM usage_logs
WHERE api_key_id = ANY($1)
AND created_at >= LEAST($2, $4)
GROUP BY api_key_id
`
today := timezone.Today()
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today)
if err != nil {
return nil, err
}
for rows.Next() {
var apiKeyID int64
var total float64
var todayTotal float64
if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[apiKeyID]; ok {
stats.TotalActualCost = total
stats.TodayActualCost = todayTotal
}
}
if err := rows.Close(); err != nil {
@@ -1670,7 +1667,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
@@ -1753,7 +1751,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
total_requests as requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
@@ -1768,7 +1767,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
total_requests as requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
@@ -1812,6 +1812,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
@@ -2245,6 +2247,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
return logs, paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
limit := params.Limit()
offset := params.Offset()
limitPos := len(args) + 1
offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), limit+1, offset)
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil {
return nil, nil, err
}
hasMore := false
if len(logs) > limit {
hasMore = true
logs = logs[:limit]
}
total := int64(offset) + int64(len(logs))
if hasMore {
// Only guarantee "there is a next page"; avoid a full COUNT(*) on very large tables.
total = int64(offset) + int64(limit) + 1
}
return logs, paginationResultFromTotal(total, params), nil
}
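The `limit+1` probe above is the standard way to learn whether a next page exists without a `COUNT(*)`: fetch one extra row, and if it comes back, report a total just past the current page. The arithmetic can be isolated as below (a sketch; `fastTotal` is a hypothetical helper, not part of the repository):

```go
package main

import "fmt"

// fastTotal reproduces the pagination math used by
// listUsageLogsWithFastPagination: the caller fetches limit+1 rows; if the
// extra row came back, report offset+limit+1 so the UI shows a next page.
func fastTotal(offset, limit, fetched int) (total int64, hasMore bool) {
	if fetched > limit {
		return int64(offset) + int64(limit) + 1, true
	}
	return int64(offset) + int64(fetched), false
}

func main() {
	// Page 3 of size 20 (offset 40): 21 rows came back, so there is more.
	total, more := fastTotal(40, 20, 21)
	fmt.Println(total, more) // 61 true

	// Last page: only 7 rows remained.
	total, more = fastTotal(60, 20, 7)
	fmt.Println(total, more) // 67 false
}
```

The reported total is therefore a lower bound, which is why `ExactTotal` exists as an opt-out for callers that need the true count.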
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
@@ -2599,7 +2630,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.CacheCreationTokens,
&row.CacheReadTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
@@ -2623,6 +2655,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.CacheCreationTokens,
&row.CacheReadTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,

View File

@@ -96,6 +96,7 @@ func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
filters := usagestats.UsageLogFilters{
RequestType: &requestType,
Stream: &stream,
ExactTotal: true,
}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
@@ -124,7 +125,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
require.NoError(t, err)
@@ -143,7 +144,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
require.NoError(t, err)

View File

@@ -243,21 +243,24 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions
if shouldLoadSubscriptions {
// Batch load active subscriptions with groups to avoid N+1.
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDIn(userIDs...),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
).
WithGroup().
All(ctx)
if err != nil {
return nil, nil, err
}
for i := range subs {
if u, ok := userMap[subs[i].UserID]; ok {
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
}
}
}

View File

@@ -53,7 +53,9 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewSoraAccountRepository, // Sora account extension-table repository
NewScheduledTestPlanRepository, // scheduled test plan repository
NewScheduledTestResultRepository, // scheduled test result repository
NewProxyRepository,
NewRedeemCodeRepository,
NewPromoCodeRepository,

View File

@@ -446,9 +446,10 @@ func TestAPIContracts(t *testing.T) {
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
@@ -487,6 +488,7 @@ func TestAPIContracts(t *testing.T) {
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
@@ -1411,7 +1413,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return nil
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID))
for id := range r.byID {
if r.byID[id].UserID == userID {

View File

@@ -56,7 +56,7 @@ func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {


@@ -537,7 +537,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
 	return errors.New("not implemented")
 }
-func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
+func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
 	return nil, nil, errors.New("not implemented")
 }


@@ -78,6 +78,9 @@ func RegisterAdminRoutes(
 		// API key management
 		registerAdminAPIKeyRoutes(admin, h)
+		// Scheduled test plans
+		registerScheduledTestRoutes(admin, h)
 	}
 }
@@ -168,6 +171,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
 		ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
 		// Dashboard (vNext - raw path for MVP)
+		ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2)
 		ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
 		ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
 		ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
@@ -180,6 +184,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
 func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
 	dashboard := admin.Group("/dashboard")
 	{
+		dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2)
 		dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
 		dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
 		dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
@@ -476,6 +481,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
 	}
 }
+func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+	plans := admin.Group("/scheduled-test-plans")
+	{
+		plans.POST("", h.Admin.ScheduledTest.Create)
+		plans.PUT("/:id", h.Admin.ScheduledTest.Update)
+		plans.DELETE("/:id", h.Admin.ScheduledTest.Delete)
+		plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults)
+	}
+	// Nested under accounts
+	admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount)
+}
 func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
 	rules := admin.Group("/error-passthrough-rules")
 	{


@@ -853,15 +853,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
 }
 const (
 	OpenAIWSIngressModeOff = "off"
 	OpenAIWSIngressModeShared = "shared"
 	OpenAIWSIngressModeDedicated = "dedicated"
+	OpenAIWSIngressModeCtxPool = "ctx_pool"
+	OpenAIWSIngressModePassthrough = "passthrough"
 )
 func normalizeOpenAIWSIngressMode(mode string) string {
 	switch strings.ToLower(strings.TrimSpace(mode)) {
 	case OpenAIWSIngressModeOff:
 		return OpenAIWSIngressModeOff
+	case OpenAIWSIngressModeCtxPool:
+		return OpenAIWSIngressModeCtxPool
+	case OpenAIWSIngressModePassthrough:
+		return OpenAIWSIngressModePassthrough
 	case OpenAIWSIngressModeShared:
 		return OpenAIWSIngressModeShared
 	case OpenAIWSIngressModeDedicated:
@@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
 func normalizeOpenAIWSIngressDefaultMode(mode string) string {
 	if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
+		if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
+			return OpenAIWSIngressModeCtxPool
+		}
 		return normalized
 	}
-	return OpenAIWSIngressModeShared
+	return OpenAIWSIngressModeCtxPool
 }
-// ResolveOpenAIResponsesWebSocketV2Mode returns the account's effective mode under WSv2 ingress (off/shared/dedicated).
+// ResolveOpenAIResponsesWebSocketV2Mode returns the account's effective mode under WSv2 ingress (off/ctx_pool/passthrough).
 //
 // Priority:
 // 1. per-type mode (new field, string)
 // 2. per-type enabled (legacy field, bool)
 // 3. compatibility enabled (legacy field, bool)
-// 4. defaultMode (falls back to shared when invalid)
+// 4. defaultMode (falls back to ctx_pool when invalid)
 func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
 	resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
 	if a == nil || !a.IsOpenAI() {
@@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
 		return "", false
 	}
 	if enabled {
-		return OpenAIWSIngressModeShared, true
+		return OpenAIWSIngressModeCtxPool, true
 	}
 	return OpenAIWSIngressModeOff, true
 }
@@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
 	if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
 		return mode
 	}
+	// Compatibility with legacy values: shared/dedicated semantics both fold into ctx_pool.
+	if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
+		return OpenAIWSIngressModeCtxPool
+	}
 	return resolvedDefault
 }
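The hunks above fold the legacy `shared`/`dedicated` defaults into `ctx_pool` and make `ctx_pool` the fallback for empty or invalid input. A standalone sketch of that folding logic follows; the constant and function names here are local stand-ins for illustration, not the repo's identifiers:

```go
package main

import (
	"fmt"
	"strings"
)

// Legacy and current ingress mode strings, as seen in the diff above.
const (
	modeOff         = "off"
	modeShared      = "shared"
	modeDedicated   = "dedicated"
	modeCtxPool     = "ctx_pool"
	modePassthrough = "passthrough"
)

// normalizeDefaultMode trims and lowercases the input, folds the legacy
// shared/dedicated values into ctx_pool, and falls back to ctx_pool for
// anything unrecognized, mirroring normalizeOpenAIWSIngressDefaultMode.
func normalizeDefaultMode(mode string) string {
	switch strings.ToLower(strings.TrimSpace(mode)) {
	case modeOff:
		return modeOff
	case modePassthrough:
		return modePassthrough
	case modeShared, modeDedicated, modeCtxPool:
		return modeCtxPool
	default:
		return modeCtxPool
	}
}

func main() {
	fmt.Println(normalizeDefaultMode("shared"))    // legacy value folds to ctx_pool
	fmt.Println(normalizeDefaultMode("off"))       // off is preserved
	fmt.Println(normalizeDefaultMode(" BOGUS \t")) // unknown input falls back to ctx_pool
}
```

This keeps old account configurations working without a data migration: stored `shared`/`dedicated` strings resolve to the new pooled mode at read time.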


@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
 }
 func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
-	t.Run("default fallback to shared", func(t *testing.T) {
+	t.Run("default fallback to ctx_pool", func(t *testing.T) {
 		account := &Account{
 			Platform: PlatformOpenAI,
 			Type:     AccountTypeOAuth,
 			Extra:    map[string]any{},
 		}
-		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
-		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
+		require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
+		require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
 	})
 	t.Run("oauth mode field has highest priority", func(t *testing.T) {
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
 			Platform: PlatformOpenAI,
 			Type:     AccountTypeOAuth,
 			Extra: map[string]any{
-				"openai_oauth_responses_websockets_v2_mode":    OpenAIWSIngressModeDedicated,
+				"openai_oauth_responses_websockets_v2_mode":    OpenAIWSIngressModePassthrough,
 				"openai_oauth_responses_websockets_v2_enabled": false,
 				"responses_websockets_v2_enabled":              false,
 			},
 		}
-		require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
+		require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
 	})
-	t.Run("legacy enabled maps to shared", func(t *testing.T) {
+	t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
 		account := &Account{
 			Platform: PlatformOpenAI,
 			Type:     AccountTypeAPIKey,
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
 				"responses_websockets_v2_enabled": true,
 			},
 		}
-		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+		require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+	})
+	t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
+		shared := &Account{
+			Platform: PlatformOpenAI,
+			Type:     AccountTypeOAuth,
+			Extra: map[string]any{
+				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
+			},
+		}
+		dedicated := &Account{
+			Platform: PlatformOpenAI,
+			Type:     AccountTypeOAuth,
+			Extra: map[string]any{
+				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
+			},
+		}
+		require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+		require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+		require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
+		require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
 	})
 	t.Run("legacy disabled maps to off", func(t *testing.T) {
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
 				"responses_websockets_v2_enabled": true,
 			},
 		}
-		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
+		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
 	})
 	t.Run("non openai always off", func(t *testing.T) {


@@ -12,6 +12,7 @@ import (
 	"io"
 	"log"
 	"net/http"
+	"net/http/httptest"
 	"net/url"
 	"regexp"
 	"strings"
@@ -33,7 +34,7 @@ import (
 var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
 const (
-	testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
+	testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
 	chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
 	soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora user-info endpoint, used for connection tests
 	soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
@@ -238,7 +239,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
 	if err != nil {
 		return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
 	}
-	apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
+	apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages?beta=true"
 } else {
 	return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
 }
@@ -1560,3 +1561,62 @@ func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) er
 	s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
 	return fmt.Errorf("%s", errorMsg)
 }
+
+// RunTestBackground executes an account test in-memory (no real HTTP client),
+// capturing SSE output via httptest.NewRecorder, then parses the result.
+func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID int64, modelID string) (*ScheduledTestResult, error) {
+	startedAt := time.Now()
+	w := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(w)
+	ginCtx.Request = (&http.Request{}).WithContext(ctx)
+	testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
+	finishedAt := time.Now()
+	body := w.Body.String()
+	responseText, errMsg := parseTestSSEOutput(body)
+	status := "success"
+	if testErr != nil || errMsg != "" {
+		status = "failed"
+		if errMsg == "" && testErr != nil {
+			errMsg = testErr.Error()
+		}
+	}
+	return &ScheduledTestResult{
+		Status:       status,
+		ResponseText: responseText,
+		ErrorMessage: errMsg,
+		LatencyMs:    finishedAt.Sub(startedAt).Milliseconds(),
+		StartedAt:    startedAt,
+		FinishedAt:   finishedAt,
+	}, nil
+}
+
+// parseTestSSEOutput extracts response text and error message from captured SSE output.
+func parseTestSSEOutput(body string) (responseText, errMsg string) {
+	var texts []string
+	for _, line := range strings.Split(body, "\n") {
+		line = strings.TrimSpace(line)
+		if !strings.HasPrefix(line, "data: ") {
+			continue
+		}
+		jsonStr := strings.TrimPrefix(line, "data: ")
+		var event TestEvent
+		if err := json.Unmarshal([]byte(jsonStr), &event); err != nil {
+			continue
+		}
+		switch event.Type {
+		case "content":
+			if event.Text != "" {
+				texts = append(texts, event.Text)
+			}
+		case "error":
+			errMsg = event.Error
+		}
+	}
+	responseText = strings.Join(texts, "")
+	return
+}
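parseTestSSEOutput above walks the captured body line by line, keeps only `data: `-framed SSE events, concatenates the text of content events, and remembers the last error event. A self-contained sketch of the same parsing follows; the JSON field tags on the event struct are assumptions inferred from the keys this diff implies, not confirmed definitions from the repo:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// testEvent mirrors the assumed TestEvent wire shape (type/text/error keys).
type testEvent struct {
	Type  string `json:"type"`
	Text  string `json:"text"`
	Error string `json:"error"`
}

// parseSSE extracts concatenated content text and the last error message
// from raw SSE output, following the parseTestSSEOutput logic above.
func parseSSE(body string) (responseText, errMsg string) {
	var texts []string
	for _, line := range strings.Split(body, "\n") {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "data: ") {
			continue // ignore comments, blank keep-alive lines, other fields
		}
		var ev testEvent
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &ev); err != nil {
			continue // skip non-JSON frames such as a "[DONE]" sentinel
		}
		switch ev.Type {
		case "content":
			if ev.Text != "" {
				texts = append(texts, ev.Text)
			}
		case "error":
			errMsg = ev.Error
		}
	}
	return strings.Join(texts, ""), errMsg
}

func main() {
	body := "data: {\"type\":\"content\",\"text\":\"Hel\"}\n" +
		"data: {\"type\":\"content\",\"text\":\"lo\"}\n" +
		"data: {\"type\":\"error\",\"error\":\"rate limited\"}\n"
	text, errMsg := parseSSE(body)
	fmt.Println(text, errMsg) // Hello rate limited
}
```

Capturing the handler's output with `httptest.NewRecorder` and re-parsing it, as RunTestBackground does, reuses the streaming code path unchanged instead of duplicating the test logic for the scheduler.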


@@ -745,7 +745,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
 func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
 	params := pagination.PaginationParams{Page: page, PageSize: pageSize}
-	keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
+	keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
 	if err != nil {
 		return nil, 0, err
 	}


@@ -91,7 +91,7 @@ func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string)
 	panic("unexpected")
 }
 func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
-func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
+func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
 	panic("unexpected")
 }
 func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {


@@ -97,3 +97,10 @@ func (k *APIKey) GetDaysUntilExpiry() int {
 	}
 	return int(duration.Hours() / 24)
 }
+
+// APIKeyListFilters holds optional filtering parameters for listing API keys.
+type APIKeyListFilters struct {
+	Search  string
+	Status  string
+	GroupID *int64 // nil = no filter, 0 = ungrouped, >0 = specific group
+}


@@ -55,7 +55,7 @@ type APIKeyRepository interface {
 	Update(ctx context.Context, key *APIKey) error
 	Delete(ctx context.Context, id int64) error
-	ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
+	ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error)
 	VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
 	CountByUserID(ctx context.Context, userID int64) (int64, error)
 	ExistsByKey(ctx context.Context, key string) (bool, error)
@@ -392,8 +392,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
 }
 // List returns the user's API key list
-func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
-	keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
+func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
+	keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters)
 	if err != nil {
 		return nil, nil, fmt.Errorf("list api keys: %w", err)
 	}


@@ -53,7 +53,7 @@ func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
 	panic("unexpected Delete call")
 }
-func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
+func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
 	panic("unexpected ListByUserID call")
 }


@@ -81,7 +81,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
 // The following methods are required by the interface but not exercised by this test
-func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
+func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
 	panic("unexpected ListByUserID call")
 }


@@ -8,6 +8,7 @@ import (
 	"errors"
 	"fmt"
 	"net/mail"
+	"strconv"
 	"strings"
 	"time"
@@ -33,6 +34,7 @@ var (
 	ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
 	ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
 	ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
+	ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
 	ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
 	ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
 	ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
@@ -115,6 +117,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
 	if isReservedEmail(email) {
 		return "", nil, ErrEmailReserved
 	}
+	if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
+		return "", nil, err
+	}
 	// Check whether an invitation code is required
 	var invitationRedeemCode *RedeemCode
@@ -241,6 +246,9 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
 	if isReservedEmail(email) {
 		return ErrEmailReserved
 	}
+	if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
+		return err
+	}
 	// Check whether the email already exists
 	existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
@@ -279,6 +287,9 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
 	if isReservedEmail(email) {
 		return nil, ErrEmailReserved
 	}
+	if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
+		return nil, err
+	}
 	// Check whether the email already exists
 	existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
@@ -624,6 +635,32 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
 	}
 }
+
+func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
+	if s.settingService == nil {
+		return nil
+	}
+	whitelist := s.settingService.GetRegistrationEmailSuffixWhitelist(ctx)
+	if !IsRegistrationEmailSuffixAllowed(email, whitelist) {
+		return buildEmailSuffixNotAllowedError(whitelist)
+	}
+	return nil
+}
+
+func buildEmailSuffixNotAllowedError(whitelist []string) error {
+	if len(whitelist) == 0 {
+		return ErrEmailSuffixNotAllowed
+	}
+	allowed := strings.Join(whitelist, ", ")
+	return infraerrors.BadRequest(
+		"EMAIL_SUFFIX_NOT_ALLOWED",
+		fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", allowed),
+	).WithMetadata(map[string]string{
+		"allowed_suffixes":     strings.Join(whitelist, ","),
+		"allowed_suffix_count": strconv.Itoa(len(whitelist)),
+	})
+}
 // ValidateToken validates a JWT token and returns the user claims
 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
 	// Validate length first to reject abnormally long tokens early, reducing DoS risk.
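`IsRegistrationEmailSuffixAllowed` is referenced above but its body is not part of this diff. A hypothetical sketch of a suffix check consistent with the surrounding tests (case-insensitive match, leading `@` optional since the tests use both `"@example.com"` and `"example.com"`, and an empty whitelist assumed to allow every address):

```go
package main

import (
	"fmt"
	"strings"
)

// suffixAllowed reports whether email ends with one of the whitelisted
// suffixes. This is an illustrative stand-in, not the repo's implementation:
// an empty whitelist allows everything, comparison is case-insensitive,
// and a missing leading "@" on a suffix is tolerated.
func suffixAllowed(email string, whitelist []string) bool {
	if len(whitelist) == 0 {
		return true
	}
	lower := strings.ToLower(email)
	for _, suffix := range whitelist {
		s := strings.ToLower(strings.TrimSpace(suffix))
		if s == "" {
			continue
		}
		if !strings.HasPrefix(s, "@") {
			s = "@" + s // normalize "example.com" to "@example.com"
		}
		if strings.HasSuffix(lower, s) {
			return true
		}
	}
	return false
}

func main() {
	wl := []string{"@example.com", "company.com"}
	fmt.Println(suffixAllowed("user@example.com", wl)) // allowed
	fmt.Println(suffixAllowed("user@other.com", wl))   // rejected
}
```

Checking the policy before the existence lookup and before sending mail, as the three call sites above do, rejects disallowed addresses cheaply and keeps the verification-code path from leaking mail to arbitrary domains.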


@@ -9,6 +9,7 @@ import (
 	"time"
 	"github.com/Wei-Shaw/sub2api/internal/config"
+	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
 	"github.com/stretchr/testify/require"
 )
@@ -231,6 +232,51 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) {
 	require.ErrorIs(t, err, ErrEmailReserved)
 }
+
+func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
+	repo := &userRepoStub{}
+	service := newAuthService(repo, map[string]string{
+		SettingKeyRegistrationEnabled:              "true",
+		SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
+	}, nil)
+	_, _, err := service.Register(context.Background(), "user@other.com", "password")
+	require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
+	appErr := infraerrors.FromError(err)
+	require.Contains(t, appErr.Message, "@example.com")
+	require.Contains(t, appErr.Message, "@company.com")
+	require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason)
+	require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
+	require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"])
+}
+
+func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
+	repo := &userRepoStub{nextID: 8}
+	service := newAuthService(repo, map[string]string{
+		SettingKeyRegistrationEnabled:              "true",
+		SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
+	}, nil)
+	_, user, err := service.Register(context.Background(), "user@example.com", "password")
+	require.NoError(t, err)
+	require.NotNil(t, user)
+	require.Equal(t, int64(8), user.ID)
+}
+
+func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
+	repo := &userRepoStub{}
+	service := newAuthService(repo, map[string]string{
+		SettingKeyRegistrationEnabled:              "true",
+		SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
+	}, nil)
+	err := service.SendVerifyCode(context.Background(), "user@other.com")
+	require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
+	appErr := infraerrors.FromError(err)
+	require.Contains(t, appErr.Message, "@example.com")
+	require.Contains(t, appErr.Message, "@company.com")
+	require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
+}
 func TestAuthService_Register_CreateError(t *testing.T) {
 	repo := &userRepoStub{createErr: errors.New("create failed")}
 	service := newAuthService(repo, map[string]string{
@@ -402,7 +448,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
 	repo := &userRepoStub{nextID: 42}
 	assigner := &defaultSubscriptionAssignerStub{}
 	service := newAuthService(repo, map[string]string{
 		SettingKeyRegistrationEnabled:  "true",
 		SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
 	}, nil)
 	service.defaultSubAssigner = assigner


@@ -74,11 +74,12 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
 // Setting keys
 const (
 	// Registration settings
 	SettingKeyRegistrationEnabled = "registration_enabled" // whether registration is open
 	SettingKeyEmailVerifyEnabled = "email_verify_enabled" // whether email verification is enabled
+	SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // registration email suffix whitelist (JSON array)
 	SettingKeyPromoCodeEnabled = "promo_code_enabled" // whether promo codes are enabled
 	SettingKeyPasswordResetEnabled = "password_reset_enabled" // whether password reset is enabled (requires email verification first)
 	SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // whether invitation-code registration is enabled
 	// Email service settings
 	SettingKeySMTPHost = "smtp_host" // SMTP server address

@@ -88,6 +88,49 @@ func TestCheckErrorPolicy(t *testing.T) {
 			body:     []byte(`overloaded service`),
 			expected: ErrorPolicyTempUnscheduled,
 		},
+		{
+			name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled",
+			account: &Account{
+				ID:       14,
+				Type:     AccountTypeOAuth,
+				Platform: PlatformAntigravity,
+				Credentials: map[string]any{
+					"temp_unschedulable_enabled": true,
+					"temp_unschedulable_rules": []any{
+						map[string]any{
+							"error_code":       float64(401),
+							"keywords":         []any{"unauthorized"},
+							"duration_minutes": float64(10),
+						},
+					},
+				},
+			},
+			statusCode: 401,
+			body:       []byte(`unauthorized`),
+			expected:   ErrorPolicyTempUnscheduled,
+		},
+		{
+			name: "temp_unschedulable_401_second_hit_upgrades_to_none",
+			account: &Account{
+				ID:       15,
+				Type:     AccountTypeOAuth,
+				Platform: PlatformAntigravity,
+				TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
+				Credentials: map[string]any{
+					"temp_unschedulable_enabled": true,
+					"temp_unschedulable_rules": []any{
+						map[string]any{
+							"error_code":       float64(401),
+							"keywords":         []any{"unauthorized"},
+							"duration_minutes": float64(10),
+						},
+					},
+				},
+			},
+			statusCode: 401,
+			body:       []byte(`unauthorized`),
+			expected:   ErrorPolicyNone,
+		},
 		{
 			name: "temp_unschedulable_body_miss_returns_none",
 			account: &Account{


@@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
 	require.NotNil(t, result)
 	require.True(t, result.Stream)
-	require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体")
-	require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String())
+	require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射")
 	require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
 	require.Empty(t, upstream.lastReq.Header.Get("authorization"))
@@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
 	require.True(t, ok)
 	bodyBytes, ok := rawBody.([]byte)
 	require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
-	require.Equal(t, body, bodyBytes)
+	require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型")
 }
 func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
@@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
 	err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
 	require.NoError(t, err)
-	require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体")
-	require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String())
+	require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射")
 	require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
 	require.Empty(t, upstream.lastReq.Header.Get("authorization"))
 	require.Empty(t, upstream.lastReq.Header.Get("cookie"))
@@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
 	require.Empty(t, rec.Header().Get("Set-Cookie"))
 }
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases covers edge cases of model mapping in passthrough mode
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
model string
modelMapping map[string]any // nil = no mapping configured
expectedModel string
endpoint string // "messages" or "count_tokens"
}{
{
name: "Forward: 无映射配置时不改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: nil,
expectedModel: "claude-sonnet-4-20250514",
endpoint: "messages",
},
{
name: "Forward: 空映射配置时不改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{},
expectedModel: "claude-sonnet-4-20250514",
endpoint: "messages",
},
{
name: "Forward: 模型不在映射表中时不改写",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
expectedModel: "claude-sonnet-4-20250514",
endpoint: "messages",
},
{
name: "Forward: 精确匹配映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "messages",
},
{
name: "Forward: 通配符映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "messages",
},
{
name: "CountTokens: 无映射配置时不改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: nil,
expectedModel: "claude-sonnet-4-20250514",
endpoint: "count_tokens",
},
{
name: "CountTokens: 模型不在映射表中时不改写",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
expectedModel: "claude-sonnet-4-20250514",
endpoint: "count_tokens",
},
{
name: "CountTokens: 精确匹配映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "count_tokens",
},
{
name: "CountTokens: 通配符映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "count_tokens",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
parsed := &ParsedRequest{
Body: body,
Model: tt.model,
}
credentials := map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
}
if tt.modelMapping != nil {
credentials["model_mapping"] = tt.modelMapping
}
account := &Account{
ID: 300,
Name: "edge-case-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: credentials,
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
if tt.endpoint == "messages" {
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
parsed.Stream = false
upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamJSON)),
},
}
svc := &GatewayService{
cfg: &config.Config{},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
result, err := svc.Forward(context.Background(), c, account, parsed)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
"Forward 上游请求体中的模型应为: %s", tt.expectedModel)
} else {
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
upstreamRespBody := `{"input_tokens":42}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
"CountTokens 上游请求体中的模型应为: %s", tt.expectedModel)
}
})
}
}
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields
// ensures model mapping replaces only the model field, leaving all other request-body fields untouched
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
// Request body with complex fields: system, thinking, messages
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`)
parsed := &ParsedRequest{
Body: body,
Model: "claude-sonnet-4-20250514",
}
upstreamRespBody := `{"input_tokens":42}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 301,
Name: "preserve-fields-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
"model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
},
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
sentBody := upstream.lastBody
require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射")
require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改")
require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改")
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改")
require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改")
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
}
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
// ensures an empty model name does not trigger the mapping logic
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
parsed := &ParsedRequest{
Body: body,
Model: "", // empty model
}
upstreamRespBody := `{"input_tokens":10}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 302,
Name: "empty-model-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
"model_mapping": map[string]any{"*": "claude-3-opus-20240229"},
},
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
// With an empty model name, the body should pass through unchanged; mapping must not fire
require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改")
}
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
	gin.SetMode(gin.TestMode)

View File

@@ -3889,7 +3889,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
	}

	if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
-		return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime)
+		passthroughBody := parsed.Body
+		passthroughModel := parsed.Model
+		if passthroughModel != "" {
+			if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel {
+				passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
+				logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
+				passthroughModel = mappedModel
+			}
+		}
+		return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
	}

	body := parsed.Body
@@ -4574,7 +4583,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
		if err != nil {
			return nil, err
		}
-		targetURL = validatedURL + "/v1/messages"
+		targetURL = validatedURL + "/v1/messages?beta=true"
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
@@ -4954,7 +4963,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
		if err != nil {
			return nil, err
		}
-		targetURL = validatedURL + "/v1/messages"
+		targetURL = validatedURL + "/v1/messages?beta=true"
		}
	}
@@ -6781,7 +6790,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
	}

	if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
-		return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body)
+		passthroughBody := parsed.Body
+		if reqModel := parsed.Model; reqModel != "" {
+			if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel {
+				passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
+				logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
+			}
+		}
+		return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
	}

	body := parsed.Body
@@ -7072,7 +7088,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
		if err != nil {
			return nil, err
		}
-		targetURL = validatedURL + "/v1/messages/count_tokens"
+		targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
@@ -7119,7 +7135,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
		if err != nil {
			return nil, err
		}
-		targetURL = validatedURL + "/v1/messages/count_tokens"
+		targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
		}
	}
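The mapping consulted by `GetMappedModel` above is not shown in this diff; the edge-case test table earlier implies that an exact key match wins and a trailing `*` acts as a prefix wildcard. A minimal sketch under those assumptions (`mapModel` is a hypothetical stand-in, not the repository's API):

```go
package main

import (
	"fmt"
	"strings"
)

// mapModel resolves an account-level model mapping: exact matches take
// precedence, and a key ending in "*" is treated as a prefix wildcard.
// Empty model names and empty mappings pass through untouched.
func mapModel(model string, mapping map[string]string) string {
	if model == "" || len(mapping) == 0 {
		return model
	}
	if target, ok := mapping[model]; ok {
		return target
	}
	for pattern, target := range mapping {
		if strings.HasSuffix(pattern, "*") &&
			strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) {
			return target
		}
	}
	return model
}

func main() {
	mapping := map[string]string{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"}
	fmt.Println(mapModel("claude-sonnet-4-20250514", mapping)) // wildcard hit
	fmt.Println(mapModel("claude-3-haiku-20240307", mapping))  // no match: unchanged
}
```

This matches the test expectation that a model absent from the table (and an empty `model_mapping`) leaves the body untouched.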

View File

@@ -122,6 +122,28 @@ func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
			body:     []byte(`overloaded service`),
			expected: ErrorPolicyTempUnscheduled,
		},
{
name: "gemini_apikey_temp_unschedulable_401_second_hit_returns_none",
account: &Account{
ID: 105,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyNone,
},
		{
			name: "gemini_custom_codes_override_temp_unschedulable",
			account: &Account{

View File

@@ -19,8 +19,10 @@ import (
// Precompiled regular expressions (avoid recompiling on every call)
var (
-	// Match user_id format: user_{64-char hex}_account__session_{uuid}
-	userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
+	// Match user_id formats:
+	// Legacy: user_{64-char hex}_account__session_{uuid} (no UUID after "account")
+	// New:    user_{64-char hex}_account_{uuid}_session_{uuid} (UUID after "account")
+	userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)

	// Match User-Agent version numbers: xxx/x.y.z
	userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
@@ -239,13 +241,16 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
		return body, nil
	}

-	// Match format: user_{64-char hex}_account__session_{uuid}
+	// Match formats:
+	// Legacy: user_{64-char hex}_account__session_{uuid}
+	// New:    user_{64-char hex}_account_{uuid}_session_{uuid}
	matches := userIDRegex.FindStringSubmatch(userID)
	if matches == nil {
		return body, nil
	}

-	sessionTail := matches[1] // original session UUID
+	// matches[1] = account UUID (may be empty), matches[2] = session UUID
+	sessionTail := matches[2] // original session UUID

	// Derive the new session hash: SHA256(accountID::sessionTail) -> UUID format
	seed := fmt.Sprintf("%d::%s", accountID, sessionTail)

View File

@@ -263,13 +263,15 @@ type OpenAIGatewayService struct {
	toolCorrector          *CodexToolCorrector
	openaiWSResolver       OpenAIWSProtocolResolver
	openaiWSPoolOnce       sync.Once
	openaiWSStateStoreOnce sync.Once
	openaiSchedulerOnce    sync.Once
-	openaiWSPool       *openAIWSConnPool
-	openaiWSStateStore OpenAIWSStateStore
-	openaiScheduler    OpenAIAccountScheduler
-	openaiAccountStats *openAIAccountRuntimeStats
+	openaiWSPassthroughDialerOnce sync.Once
+	openaiWSPool                  *openAIWSConnPool
+	openaiWSStateStore            OpenAIWSStateStore
+	openaiScheduler               OpenAIAccountScheduler
+	openaiWSPassthroughDialer     openAIWSClientDialer
+	openaiAccountStats            *openAIAccountRuntimeStats
	openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
	openaiWSRetryMetrics  openAIWSRetryMetrics

View File

@@ -11,6 +11,7 @@ import (
	"sync/atomic"
	"time"

+	openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
	coderws "github.com/coder/websocket"
	"github.com/coder/websocket/wsjson"
)
@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
	conn *coderws.Conn
}

+var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil)

func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed
@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
	}
}
func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.conn == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
msgType, payload, err := c.conn.Read(ctx)
if err != nil {
return coderws.MessageText, nil, err
}
return msgType, payload, nil
}
func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Write(ctx, msgType, payload)
}
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed

View File

@@ -46,9 +46,10 @@ const (
	openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
	openAIWSPayloadSizeEstimateMaxItems = 16
	openAIWSEventFlushBatchSizeDefault  = 4
	openAIWSEventFlushIntervalDefault   = 25 * time.Millisecond
	openAIWSPayloadLogSampleDefault     = 0.2
+	openAIWSPassthroughIdleTimeoutDefault = time.Hour

	openAIWSStoreDisabledConnModeStrict   = "strict"
	openAIWSStoreDisabledConnModeAdaptive = "adaptive"
@@ -904,6 +905,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
	return s.openaiWSPool
}
func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer {
if s == nil {
return nil
}
s.openaiWSPassthroughDialerOnce.Do(func() {
if s.openaiWSPassthroughDialer == nil {
s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer()
}
})
return s.openaiWSPassthroughDialer
}
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
	pool := s.getOpenAIWSConnPool()
	if pool == nil {
@@ -967,6 +980,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
	return 15 * time.Minute
}
func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration {
if timeout := s.openAIWSReadTimeout(); timeout > 0 {
return timeout
}
return openAIWSPassthroughIdleTimeoutDefault
}
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
	if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
		return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
@@ -2322,7 +2342,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
	wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
	modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled

-	ingressMode := OpenAIWSIngressModeShared
+	ingressMode := OpenAIWSIngressModeCtxPool
	if modeRouterV2Enabled {
		ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
		if ingressMode == OpenAIWSIngressModeOff {
@@ -2332,6 +2352,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
				nil,
			)
		}
switch ingressMode {
case OpenAIWSIngressModePassthrough:
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
}
return s.proxyResponsesWebSocketV2Passthrough(
ctx,
c,
clientConn,
account,
token,
firstClientMessage,
hooks,
wsDecision,
)
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
// continue
default:
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"websocket mode only supports ctx_pool/passthrough",
nil,
)
}
	}

	if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
		return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)

View File

@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
	require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
	require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")

-	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+	_ = clientConn.Close(coderws.StatusNormalClosure, "done")

	select {
	case serverErr := <-serverErrCh:
@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
	require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
upstreamConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPassthroughDialer: captureDialer,
}
account := &Account{
ID: 452,
Name: "openai-ingress-passthrough",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
},
}
serverErrCh := make(chan error, 1)
resultCh := make(chan *OpenAIForwardResult, 1)
hooks := &OpenAIWSIngressHooks{
AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
if turnErr == nil && result != nil {
resultCh <- result
}
},
}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String())
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 passthrough websocket 结束超时")
}
select {
case result := <-resultCh:
require.Equal(t, "resp_passthrough_turn_1", result.RequestID)
require.True(t, result.OpenAIWSMode)
require.Equal(t, 2, result.Usage.InputTokens)
require.Equal(t, 3, result.Usage.OutputTokens)
case <-time.After(2 * time.Second):
t.Fatal("未收到 passthrough turn 结果回调")
}
require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket")
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
	gin.SetMode(gin.TestMode)

View File

@@ -15,6 +15,7 @@ import (
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
+	coderws "github.com/coder/websocket"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/stretchr/testify/require"
@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
	return event, nil
}
func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
payload, err := c.ReadMessage(ctx)
if err != nil {
return coderws.MessageText, nil, err
}
return coderws.MessageText, payload, nil
}
func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error {
return c.WriteJSON(ctx, json.RawMessage(payload))
}
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
	_ = ctx
	return nil

View File

@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
	switch mode {
	case OpenAIWSIngressModeOff:
		return openAIWSHTTPDecision("account_mode_off")
-	case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
+	case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough:
		// continue
+	case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
+		// Legacy values: treated as ctx_pool for compatibility.
+		mode = OpenAIWSIngressModeCtxPool
	default:
		return openAIWSHTTPDecision("account_mode_off")
	}

View File

@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
-	cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
+	cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool

	account := &Account{
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Extra: map[string]any{
-			"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
+			"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
		},
	}

-	t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
+	t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
-		require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
+		require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
	})

	t.Run("off mode routes to http", func(t *testing.T) {
@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
		require.Equal(t, "account_mode_off", decision.Reason)
	})

-	t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
+	t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
		legacyAccount := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
		}

		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
-		require.Equal(t, "ws_v2_mode_shared", decision.Reason)
+		require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("passthrough mode routes to ws v2", func(t *testing.T) {
passthroughAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_passthrough", decision.Reason)
	})

	t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
-				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
+				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
			},
		}

		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)

View File

@@ -0,0 +1,24 @@
package openai_ws_v2
import (
"context"
)
// runCaddyStyleRelay follows the bidirectional-tunnel idea from Caddy's reverseproxy:
// once the connections are established, both directions are copied concurrently,
// and either direction exiting triggers a converging shutdown.
//
// Reference:
// - Project: caddyserver/caddy (Apache-2.0)
// - Commit: f283062d37c50627d53ca682ebae2ce219b35515
// - Files:
// - modules/caddyhttp/reverseproxy/streaming.go
// - modules/caddyhttp/reverseproxy/reverseproxy.go
func runCaddyStyleRelay(
ctx context.Context,
clientConn FrameConn,
upstreamConn FrameConn,
firstClientMessage []byte,
options RelayOptions,
) (RelayResult, *RelayExit) {
return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
}

View File

@@ -0,0 +1,23 @@
package openai_ws_v2
import "context"
// EntryInput is the set of entry parameters for the passthrough v2 data plane.
type EntryInput struct {
Ctx context.Context
ClientConn FrameConn
UpstreamConn FrameConn
FirstClientMessage []byte
Options RelayOptions
}
// RunEntry is the single public entry point of the openai_ws_v2 package.
func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
return runCaddyStyleRelay(
input.Ctx,
input.ClientConn,
input.UpstreamConn,
input.FirstClientMessage,
input.Options,
)
}

View File

@@ -0,0 +1,29 @@
package openai_ws_v2
import (
"sync/atomic"
)
// MetricsSnapshot is a lightweight runtime metrics snapshot of the OpenAI WS v2 passthrough path.
type MetricsSnapshot struct {
SemanticMutationTotal int64 `json:"semantic_mutation_total"`
UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
}
var (
	// The passthrough path does no semantic rewriting by default, so this counter
	// should normally stay at 0; it is kept for future defensive validation.
passthroughSemanticMutationTotal atomic.Int64
passthroughUsageParseFailureTotal atomic.Int64
)
func recordUsageParseFailure() {
passthroughUsageParseFailureTotal.Add(1)
}
// SnapshotMetrics returns the current passthrough metrics snapshot.
func SnapshotMetrics() MetricsSnapshot {
return MetricsSnapshot{
SemanticMutationTotal: passthroughSemanticMutationTotal.Load(),
UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(),
}
}

View File

@@ -0,0 +1,807 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
coderws "github.com/coder/websocket"
"github.com/tidwall/gjson"
)
type FrameConn interface {
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
Close() error
}
type Usage struct {
InputTokens int
OutputTokens int
CacheCreationInputTokens int
CacheReadInputTokens int
}
type RelayResult struct {
RequestModel string
Usage Usage
RequestID string
TerminalEventType string
FirstTokenMs *int
Duration time.Duration
ClientToUpstreamFrames int64
UpstreamToClientFrames int64
DroppedDownstreamFrames int64
}
type RelayTurnResult struct {
RequestModel string
Usage Usage
RequestID string
TerminalEventType string
Duration time.Duration
FirstTokenMs *int
}
type RelayExit struct {
Stage string
Err error
WroteDownstream bool
}
type RelayOptions struct {
WriteTimeout time.Duration
IdleTimeout time.Duration
UpstreamDrainTimeout time.Duration
FirstMessageType coderws.MessageType
OnUsageParseFailure func(eventType string, usageRaw string)
OnTurnComplete func(turn RelayTurnResult)
OnTrace func(event RelayTraceEvent)
Now func() time.Time
}
type RelayTraceEvent struct {
Stage string
Direction string
MessageType string
PayloadBytes int
Graceful bool
WroteDownstream bool
Error string
}
type relayState struct {
usage Usage
requestModel string
lastResponseID string
terminalEventType string
firstTokenMs *int
turnTimingByID map[string]*relayTurnTiming
}
type relayExitSignal struct {
stage string
err error
graceful bool
wroteDownstream bool
}
type observedUpstreamEvent struct {
terminal bool
eventType string
responseID string
usage Usage
duration time.Duration
firstToken *int
}
type relayTurnTiming struct {
startAt time.Time
firstTokenMs *int
}
func Relay(
ctx context.Context,
clientConn FrameConn,
upstreamConn FrameConn,
firstClientMessage []byte,
options RelayOptions,
) (RelayResult, *RelayExit) {
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
if clientConn == nil || upstreamConn == nil {
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
}
if ctx == nil {
ctx = context.Background()
}
nowFn := options.Now
if nowFn == nil {
nowFn = time.Now
}
writeTimeout := options.WriteTimeout
if writeTimeout <= 0 {
writeTimeout = 2 * time.Minute
}
drainTimeout := options.UpstreamDrainTimeout
if drainTimeout <= 0 {
drainTimeout = 1200 * time.Millisecond
}
firstMessageType := options.FirstMessageType
if firstMessageType != coderws.MessageBinary {
firstMessageType = coderws.MessageText
}
startAt := nowFn()
state := &relayState{requestModel: result.RequestModel}
onTrace := options.OnTrace
relayCtx, relayCancel := context.WithCancel(ctx)
defer relayCancel()
lastActivity := atomic.Int64{}
lastActivity.Store(nowFn().UnixNano())
markActivity := func() {
lastActivity.Store(nowFn().UnixNano())
}
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
defer cancel()
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
}
writeClient := func(msgType coderws.MessageType, payload []byte) error {
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
defer cancel()
return clientConn.WriteFrame(writeCtx, msgType, payload)
}
clientToUpstreamFrames := &atomic.Int64{}
upstreamToClientFrames := &atomic.Int64{}
droppedDownstreamFrames := &atomic.Int64{}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_start",
PayloadBytes: len(firstClientMessage),
MessageType: relayMessageTypeString(firstMessageType),
})
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
result.Duration = nowFn().Sub(startAt)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_failed",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
Error: err.Error(),
})
return result, &RelayExit{Stage: "write_upstream", Err: err}
}
clientToUpstreamFrames.Add(1)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_ok",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
})
markActivity()
exitCh := make(chan relayExitSignal, 3)
dropDownstreamWrites := atomic.Bool{}
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
go runUpstreamToClient(
relayCtx,
upstreamConn,
writeClient,
startAt,
nowFn,
state,
options.OnUsageParseFailure,
options.OnTurnComplete,
&dropDownstreamWrites,
upstreamToClientFrames,
droppedDownstreamFrames,
markActivity,
onTrace,
exitCh,
)
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
firstExit := <-exitCh
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "first_exit",
Direction: relayDirectionFromStage(firstExit.stage),
Graceful: firstExit.graceful,
WroteDownstream: firstExit.wroteDownstream,
Error: relayErrorString(firstExit.err),
})
combinedWroteDownstream := firstExit.wroteDownstream
secondExit := relayExitSignal{graceful: true}
hasSecondExit := false
	// After the client disconnects, keep reading upstream on a best-effort basis for a
	// short window, to capture late usage/terminal events needed for billing.
if firstExit.stage == "read_client" && firstExit.graceful {
dropDownstreamWrites.Store(true)
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
} else {
relayCancel()
_ = upstreamConn.Close()
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
}
if hasSecondExit {
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "second_exit",
Direction: relayDirectionFromStage(secondExit.stage),
Graceful: secondExit.graceful,
WroteDownstream: secondExit.wroteDownstream,
Error: relayErrorString(secondExit.err),
})
}
relayCancel()
_ = upstreamConn.Close()
enrichResult(&result, state, nowFn().Sub(startAt))
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
if firstExit.stage == "read_client" && firstExit.graceful {
stage := "client_disconnected"
exitErr := firstExit.err
if hasSecondExit && !secondExit.graceful {
stage = secondExit.stage
exitErr = secondExit.err
}
if exitErr == nil {
exitErr = io.EOF
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(exitErr),
})
return result, &RelayExit{
Stage: stage,
Err: exitErr,
WroteDownstream: combinedWroteDownstream,
}
}
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
_ = clientConn.Close()
return result, nil
}
if !firstExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(firstExit.stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(firstExit.err),
})
return result, &RelayExit{
Stage: firstExit.stage,
Err: firstExit.err,
WroteDownstream: combinedWroteDownstream,
}
}
if hasSecondExit && !secondExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(secondExit.stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(secondExit.err),
})
return result, &RelayExit{
Stage: secondExit.stage,
Err: secondExit.err,
WroteDownstream: combinedWroteDownstream,
}
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
_ = clientConn.Close()
return result, nil
}
func runClientToUpstream(
ctx context.Context,
clientConn FrameConn,
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
markActivity func(),
forwardedFrames *atomic.Int64,
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
for {
msgType, payload, err := clientConn.ReadFrame(ctx)
if err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "read_client_failed",
Direction: "client_to_upstream",
Error: err.Error(),
Graceful: isDisconnectError(err),
})
exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)}
return
}
markActivity()
if err := writeUpstream(msgType, payload); err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_upstream_failed",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
Error: err.Error(),
})
exitCh <- relayExitSignal{stage: "write_upstream", err: err}
return
}
if forwardedFrames != nil {
forwardedFrames.Add(1)
}
markActivity()
}
}
func runUpstreamToClient(
ctx context.Context,
upstreamConn FrameConn,
writeClient func(msgType coderws.MessageType, payload []byte) error,
startAt time.Time,
nowFn func() time.Time,
state *relayState,
onUsageParseFailure func(eventType string, usageRaw string),
onTurnComplete func(turn RelayTurnResult),
dropDownstreamWrites *atomic.Bool,
forwardedFrames *atomic.Int64,
droppedFrames *atomic.Int64,
markActivity func(),
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
wroteDownstream := false
for {
msgType, payload, err := upstreamConn.ReadFrame(ctx)
if err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "read_upstream_failed",
Direction: "upstream_to_client",
Error: err.Error(),
Graceful: isDisconnectError(err),
WroteDownstream: wroteDownstream,
})
exitCh <- relayExitSignal{
stage: "read_upstream",
err: err,
graceful: isDisconnectError(err),
wroteDownstream: wroteDownstream,
}
return
}
markActivity()
observedEvent := observedUpstreamEvent{}
switch msgType {
case coderws.MessageText:
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure)
case coderws.MessageBinary:
			// Binary frames are passed through as-is and skip the JSON observation
			// path (avoids pointless parsing overhead).
}
emitTurnComplete(onTurnComplete, state, observedEvent)
if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
if droppedFrames != nil {
droppedFrames.Add(1)
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "drop_downstream_frame",
Direction: "upstream_to_client",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
WroteDownstream: wroteDownstream,
})
if observedEvent.terminal {
exitCh <- relayExitSignal{
stage: "drain_terminal",
graceful: true,
wroteDownstream: wroteDownstream,
}
return
}
markActivity()
continue
}
if err := writeClient(msgType, payload); err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_client_failed",
Direction: "upstream_to_client",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
WroteDownstream: wroteDownstream,
Error: err.Error(),
})
exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
return
}
wroteDownstream = true
if forwardedFrames != nil {
forwardedFrames.Add(1)
}
markActivity()
}
}
func runIdleWatchdog(
ctx context.Context,
nowFn func() time.Time,
idleTimeout time.Duration,
lastActivity *atomic.Int64,
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
if idleTimeout <= 0 {
return
}
checkInterval := minDuration(idleTimeout/4, 5*time.Second)
if checkInterval < time.Second {
checkInterval = time.Second
}
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
last := time.Unix(0, lastActivity.Load())
if nowFn().Sub(last) < idleTimeout {
continue
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "idle_timeout_triggered",
Direction: "watchdog",
Error: context.DeadlineExceeded.Error(),
})
exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
return
}
}
}
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
if onTrace == nil {
return
}
onTrace(event)
}
func relayMessageTypeString(msgType coderws.MessageType) string {
switch msgType {
case coderws.MessageText:
return "text"
case coderws.MessageBinary:
return "binary"
default:
return "unknown(" + strconv.Itoa(int(msgType)) + ")"
}
}
func relayDirectionFromStage(stage string) string {
switch stage {
case "read_client", "write_upstream":
return "client_to_upstream"
case "read_upstream", "write_client", "drain_terminal":
return "upstream_to_client"
case "idle_timeout":
return "watchdog"
default:
return ""
}
}
func relayErrorString(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func observeUpstreamMessage(
state *relayState,
message []byte,
startAt time.Time,
nowFn func() time.Time,
onUsageParseFailure func(eventType string, usageRaw string),
) observedUpstreamEvent {
if state == nil || len(message) == 0 {
return observedUpstreamEvent{}
}
values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
eventType := strings.TrimSpace(values[0].String())
if eventType == "" {
return observedUpstreamEvent{}
}
responseID := strings.TrimSpace(values[1].String())
if responseID == "" {
responseID = strings.TrimSpace(values[2].String())
}
	// Only terminal events fall back to the top-level id, to avoid associating an
	// event_id with a turn as if it were a response_id.
if responseID == "" && isTerminalEvent(eventType) {
responseID = strings.TrimSpace(values[3].String())
}
now := nowFn()
if state.firstTokenMs == nil && isTokenEvent(eventType) {
ms := int(now.Sub(startAt).Milliseconds())
if ms >= 0 {
state.firstTokenMs = &ms
}
}
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
observed := observedUpstreamEvent{
eventType: eventType,
responseID: responseID,
usage: parsedUsage,
}
if responseID != "" {
turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
ms := int(now.Sub(turnTiming.startAt).Milliseconds())
if ms >= 0 {
turnTiming.firstTokenMs = &ms
}
}
}
if !isTerminalEvent(eventType) {
return observed
}
observed.terminal = true
state.terminalEventType = eventType
if responseID != "" {
state.lastResponseID = responseID
if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
duration := now.Sub(turnTiming.startAt)
if duration < 0 {
duration = 0
}
observed.duration = duration
observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
}
}
return observed
}
func emitTurnComplete(
onTurnComplete func(turn RelayTurnResult),
state *relayState,
observed observedUpstreamEvent,
) {
if onTurnComplete == nil || !observed.terminal {
return
}
responseID := strings.TrimSpace(observed.responseID)
if responseID == "" {
return
}
requestModel := ""
if state != nil {
requestModel = state.requestModel
}
onTurnComplete(RelayTurnResult{
RequestModel: requestModel,
Usage: observed.usage,
RequestID: responseID,
TerminalEventType: observed.eventType,
Duration: observed.duration,
FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
})
}
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
if state == nil {
return nil
}
if state.turnTimingByID == nil {
state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
}
timing, ok := state.turnTimingByID[responseID]
if !ok || timing == nil || timing.startAt.IsZero() {
timing = &relayTurnTiming{startAt: now}
state.turnTimingByID[responseID] = timing
return timing
}
return timing
}
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
if state == nil || state.turnTimingByID == nil {
return relayTurnTiming{}, false
}
timing, ok := state.turnTimingByID[responseID]
if !ok || timing == nil {
return relayTurnTiming{}, false
}
delete(state.turnTimingByID, responseID)
return *timing, true
}
func openAIWSRelayCloneIntPtr(v *int) *int {
if v == nil {
return nil
}
cloned := *v
return &cloned
}
func parseUsageAndAccumulate(
state *relayState,
message []byte,
eventType string,
onParseFailure func(eventType string, usageRaw string),
) Usage {
if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
return Usage{}
}
usageResult := gjson.GetBytes(message, "response.usage")
if !usageResult.Exists() {
return Usage{}
}
usageRaw := strings.TrimSpace(usageResult.Raw)
if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
recordUsageParseFailure()
if onParseFailure != nil {
onParseFailure(eventType, usageRaw)
}
return Usage{}
}
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
inputTokens, inputOK := parseUsageIntField(inputResult, true)
outputTokens, outputOK := parseUsageIntField(outputResult, true)
cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
if !inputOK || !outputOK || !cachedOK {
recordUsageParseFailure()
if onParseFailure != nil {
onParseFailure(eventType, usageRaw)
}
		// On parse failure, do not accumulate partial fields; this keeps billing
		// usage from ending up in a "half-valid" state.
return Usage{}
}
parsedUsage := Usage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheReadInputTokens: cachedTokens,
}
state.usage.InputTokens += parsedUsage.InputTokens
state.usage.OutputTokens += parsedUsage.OutputTokens
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
return parsedUsage
}
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
if !value.Exists() {
return 0, !required
}
if value.Type != gjson.Number {
return 0, false
}
return int(value.Int()), true
}
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
if result == nil {
return
}
result.Duration = duration
if state == nil {
return
}
result.RequestModel = state.requestModel
result.Usage = state.usage
result.RequestID = state.lastResponseID
result.TerminalEventType = state.terminalEventType
result.FirstTokenMs = state.firstTokenMs
}
func isDisconnectError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
return true
}
switch coderws.CloseStatus(err) {
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
return true
}
message := strings.ToLower(strings.TrimSpace(err.Error()))
if message == "" {
return false
}
return strings.Contains(message, "failed to read frame header: eof") ||
strings.Contains(message, "unexpected eof") ||
strings.Contains(message, "use of closed network connection") ||
strings.Contains(message, "connection reset by peer") ||
strings.Contains(message, "broken pipe")
}
func isTerminalEvent(eventType string) bool {
switch eventType {
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false
}
}
func shouldParseUsage(eventType string) bool {
switch eventType {
case "response.completed", "response.done", "response.failed":
return true
default:
return false
}
}
func isTokenEvent(eventType string) bool {
if eventType == "" {
return false
}
switch eventType {
case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
return false
}
if strings.Contains(eventType, ".delta") {
return true
}
if strings.HasPrefix(eventType, "response.output_text") {
return true
}
if strings.HasPrefix(eventType, "response.output") {
return true
}
return eventType == "response.completed" || eventType == "response.done"
}
func minDuration(a, b time.Duration) time.Duration {
if a <= 0 {
return b
}
if b <= 0 {
return a
}
if a < b {
return a
}
return b
}
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
if timeout <= 0 {
timeout = 200 * time.Millisecond
}
select {
case sig := <-exitCh:
return sig, true
case <-time.After(timeout):
return relayExitSignal{}, false
}
}

View File

@@ -0,0 +1,432 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestRunEntry_DelegatesRelay(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true)
result, relayExit := RunEntry(EntryInput{
Ctx: context.Background(),
ClientConn: clientConn,
UpstreamConn: upstreamConn,
FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
})
require.Nil(t, relayExit)
require.Equal(t, "resp_entry", result.RequestID)
}
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
t.Parallel()
t.Run("read client eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.True(t, sig.graceful)
})
t.Run("write upstream failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_upstream", sig.stage)
require.False(t, sig.graceful)
})
t.Run("forwarded counter and trace callback", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
forwarded := &atomic.Int64{}
traces := make([]RelayTraceEvent, 0, 2)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
forwarded,
func(event RelayTraceEvent) {
traces = append(traces, event)
},
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.Equal(t, int64(1), forwarded.Load())
require.NotEmpty(t, traces)
})
}
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
t.Parallel()
t.Run("read upstream eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_upstream", sig.stage)
require.True(t, sig.graceful)
})
t.Run("write client failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_client", sig.stage)
})
t.Run("drop downstream and stop on terminal", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(true)
dropped := &atomic.Int64{}
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
dropped,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "drain_terminal", sig.stage)
require.True(t, sig.graceful)
require.Equal(t, int64(1), dropped.Load())
})
}
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
lastActivity := &atomic.Int64{}
lastActivity.Store(time.Now().UnixNano())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
select {
case <-exitCh:
t.Fatal("unexpected idle timeout signal")
case <-time.After(200 * time.Millisecond):
}
}
func TestHelperFunctionsCoverage(t *testing.T) {
t.Parallel()
require.Equal(t, "text", relayMessageTypeString(coderws.MessageText))
require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary))
require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(")
require.Equal(t, "", relayErrorString(nil))
require.Equal(t, "x", relayErrorString(errors.New("x")))
require.True(t, isDisconnectError(io.EOF))
require.True(t, isDisconnectError(net.ErrClosed))
require.True(t, isDisconnectError(context.Canceled))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway}))
require.True(t, isDisconnectError(errors.New("broken pipe")))
require.False(t, isDisconnectError(errors.New("unrelated")))
require.True(t, isTokenEvent("response.output_text.delta"))
require.True(t, isTokenEvent("response.output_audio.delta"))
require.True(t, isTokenEvent("response.completed"))
require.False(t, isTokenEvent(""))
require.False(t, isTokenEvent("response.created"))
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second))
require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0))
ch := make(chan relayExitSignal, 1)
ch <- relayExitSignal{stage: "ok"}
sig, ok := waitRelayExit(ch, 10*time.Millisecond)
require.True(t, ok)
require.Equal(t, "ok", sig.stage)
ch <- relayExitSignal{stage: "ok2"}
sig, ok = waitRelayExit(ch, 0)
require.True(t, ok)
require.Equal(t, "ok2", sig.stage)
_, ok = waitRelayExit(ch, 10*time.Millisecond)
require.False(t, ok)
n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true)
require.True(t, ok)
require.Equal(t, 3, n)
_, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true)
require.False(t, ok)
n, ok = parseUsageIntField(gjson.Result{}, false)
require.True(t, ok)
require.Equal(t, 0, n)
_, ok = parseUsageIntField(gjson.Result{}, true)
require.False(t, ok)
}
func TestParseUsageAndEnrichCoverage(t *testing.T) {
t.Parallel()
state := &relayState{}
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil)
require.Equal(t, 0, state.usage.InputTokens)
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
	require.Equal(t, 0, state.usage.InputTokens, "usage must not be accumulated when some fields fail to parse")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
	require.Equal(t, 0, state.usage.InputTokens, "usage must not be accumulated when a required usage field is missing")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
require.Equal(t, 2, state.usage.InputTokens)
require.Equal(t, 1, state.usage.OutputTokens)
require.Equal(t, 1, state.usage.CacheReadInputTokens)
result := &RelayResult{}
enrichResult(result, state, 5*time.Millisecond)
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
require.Equal(t, 5*time.Millisecond, result.Duration)
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
require.Equal(t, 2, state.usage.InputTokens)
enrichResult(nil, state, 0)
}
func TestEmitTurnCompleteCoverage(t *testing.T) {
t.Parallel()
	// Non-terminal events must not trigger the callback.
called := 0
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: false,
eventType: "response.output_text.delta",
responseID: "resp_ignored",
usage: Usage{InputTokens: 1},
})
require.Equal(t, 0, called)
	// A missing response_id must not trigger the callback.
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
})
require.Equal(t, 0, called)
	// A terminal event with a response_id should trigger; with state=nil the model is an empty string.
var got RelayTurnResult
emitTurnComplete(func(turn RelayTurnResult) {
called++
got = turn
}, nil, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
responseID: "resp_emit",
usage: Usage{InputTokens: 2, OutputTokens: 3},
})
require.Equal(t, 1, called)
require.Equal(t, "resp_emit", got.RequestID)
require.Equal(t, "response.completed", got.TerminalEventType)
require.Equal(t, 2, got.Usage.InputTokens)
require.Equal(t, 3, got.Usage.OutputTokens)
require.Equal(t, "", got.RequestModel)
}
func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) {
t.Parallel()
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure}))
require.True(t, isDisconnectError(errors.New("connection reset by peer")))
require.False(t, isDisconnectError(errors.New(" ")))
}
func TestIsTokenEventCoverageBranches(t *testing.T) {
t.Parallel()
require.False(t, isTokenEvent("response.in_progress"))
require.False(t, isTokenEvent("response.output_item.added"))
require.True(t, isTokenEvent("response.output_audio.delta"))
require.True(t, isTokenEvent("response.output"))
require.True(t, isTokenEvent("response.done"))
}
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
t.Parallel()
now := time.Unix(100, 0)
// nil state
require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now))
_, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
require.False(t, ok)
state := &relayState{}
timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now)
require.NotNil(t, timing)
require.Equal(t, now, timing.startAt)
// A second lookup returns the same timing entry.
timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second))
require.NotNil(t, timing2)
require.Equal(t, now, timing2.startAt)
// Deleting an existing key succeeds.
deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a")
require.True(t, ok)
require.Equal(t, now, deleted.startAt)
// Deleting a missing key reports false.
_, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a")
require.False(t, ok)
}
func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) {
t.Parallel()
state := &relayState{requestModel: "gpt-5"}
startAt := time.Unix(0, 0)
now := startAt
nowFn := func() time.Time {
now = now.Add(5 * time.Millisecond)
return now
}
// Non-terminal with only a top-level id: the event id must not be treated as the response_id.
observed := observeUpstreamMessage(
state,
[]byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`),
startAt,
nowFn,
nil,
)
require.False(t, observed.terminal)
require.Equal(t, "", observed.responseID)
// Terminal: fall back to the top-level id to stay compatible with rare field variants.
observed = observeUpstreamMessage(
state,
[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`),
startAt,
nowFn,
nil,
)
require.True(t, observed.terminal)
require.Equal(t, "resp_fallback", observed.responseID)
}
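The turn-emission tests above assert the same gating rule three times: a turn callback fires only for terminal events that also carry a response_id. As a minimal standalone sketch (the name `shouldEmitTurn` is illustrative and not part of the relay package, which tracks more state):

```go
package main

import "fmt"

// shouldEmitTurn mirrors the gating asserted by TestEmitTurnCompleteCoverage:
// emit only when the upstream event is terminal AND has a response_id.
func shouldEmitTurn(terminal bool, responseID string) bool {
	return terminal && responseID != ""
}

func main() {
	fmt.Println(shouldEmitTurn(false, "resp_ignored")) // non-terminal: false
	fmt.Println(shouldEmitTurn(true, ""))              // no response_id: false
	fmt.Println(shouldEmitTurn(true, "resp_emit"))     // true
}
```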


@@ -0,0 +1,752 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"sync"
"sync/atomic"
"testing"
"time"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
)
type passthroughTestFrame struct {
msgType coderws.MessageType
payload []byte
}
type passthroughTestFrameConn struct {
mu sync.Mutex
writes []passthroughTestFrame
readCh chan passthroughTestFrame
once sync.Once
}
type delayedReadFrameConn struct {
base FrameConn
firstDelay time.Duration
once sync.Once
}
type closeSpyFrameConn struct {
closeCalls atomic.Int32
}
func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn {
c := &passthroughTestFrameConn{
readCh: make(chan passthroughTestFrame, len(frames)+1),
}
for _, frame := range frames {
copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)}
c.readCh <- copied
}
if autoClose {
close(c.readCh)
}
return c
}
func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return coderws.MessageText, nil, ctx.Err()
case frame, ok := <-c.readCh:
if !ok {
return coderws.MessageText, nil, io.EOF
}
return frame.msgType, append([]byte(nil), frame.payload...), nil
}
}
func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
c.mu.Lock()
defer c.mu.Unlock()
c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)})
return nil
}
func (c *passthroughTestFrameConn) Close() error {
c.once.Do(func() {
defer func() { _ = recover() }()
close(c.readCh)
})
return nil
}
func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame {
c.mu.Lock()
defer c.mu.Unlock()
out := make([]passthroughTestFrame, len(c.writes))
copy(out, c.writes)
return out
}
func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.base == nil {
return coderws.MessageText, nil, io.EOF
}
c.once.Do(func() {
if c.firstDelay > 0 {
timer := time.NewTimer(c.firstDelay)
defer timer.Stop()
select {
case <-ctx.Done():
case <-timer.C:
}
}
})
return c.base.ReadFrame(ctx)
}
func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.base == nil {
return io.EOF
}
return c.base.WriteFrame(ctx, msgType, payload)
}
func (c *delayedReadFrameConn) Close() error {
if c == nil || c.base == nil {
return nil
}
return c.base.Close()
}
func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if ctx == nil {
ctx = context.Background()
}
<-ctx.Done()
return coderws.MessageText, nil, ctx.Err()
}
func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error {
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}
func (c *closeSpyFrameConn) Close() error {
if c != nil {
c.closeCalls.Add(1)
}
return nil
}
func (c *closeSpyFrameConn) CloseCalls() int32 {
if c == nil {
return 0
}
return c.closeCalls.Load()
}
func TestRelay_BasicRelayAndUsage(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
require.Equal(t, "gpt-5.3-codex", result.RequestModel)
require.Equal(t, "resp_123", result.RequestID)
require.Equal(t, "response.completed", result.TerminalEventType)
require.Equal(t, 7, result.Usage.InputTokens)
require.Equal(t, 3, result.Usage.OutputTokens)
require.Equal(t, 2, result.Usage.CacheReadInputTokens)
require.NotNil(t, result.FirstTokenMs)
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
require.Equal(t, int64(1), result.UpstreamToClientFrames)
require.Equal(t, int64(0), result.DroppedDownstreamFrames)
upstreamWrites := upstreamConn.Writes()
require.Len(t, upstreamWrites, 1)
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload))
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload))
}
func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
upstreamWrites := upstreamConn.Writes()
require.Len(t, upstreamWrites, 1)
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
require.Equal(t, firstPayload, upstreamWrites[0].payload)
}
func TestRelay_UpstreamDisconnect(t *testing.T) {
t.Parallel()
// Upstream closes immediately (EOF); the client sends no extra frames.
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, true) // closed immediately -> EOF
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
// Upstream EOF counts as a disconnect and is marked graceful.
require.Nil(t, relayExit, "upstream EOF should be treated as a graceful disconnect")
require.Equal(t, "gpt-4o", result.RequestModel)
}
func TestRelay_ClientDisconnect(t *testing.T) {
t.Parallel()
// Client closes immediately (EOF); the upstream read blocks until the context is canceled.
clientConn := newPassthroughTestFrameConn(nil, true) // closed immediately -> EOF
upstreamConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.NotNil(t, relayExit, "client EOF should surface an observable exit state")
require.Equal(t, "client_disconnected", relayExit.Stage)
require.Equal(t, "gpt-4o", result.RequestModel)
}
func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, true)
upstreamBase := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`),
},
}, true)
upstreamConn := &delayedReadFrameConn{
base: upstreamBase,
firstDelay: 80 * time.Millisecond,
}
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
UpstreamDrainTimeout: 400 * time.Millisecond,
})
require.NotNil(t, relayExit)
require.Equal(t, "client_disconnected", relayExit.Stage)
require.Equal(t, "resp_drain", result.RequestID)
require.Equal(t, "response.completed", result.TerminalEventType)
require.Equal(t, 6, result.Usage.InputTokens)
require.Equal(t, 4, result.Usage.OutputTokens)
require.Equal(t, 1, result.Usage.CacheReadInputTokens)
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
require.Equal(t, int64(0), result.UpstreamToClientFrames)
require.Equal(t, int64(1), result.DroppedDownstreamFrames)
}
func TestRelay_IdleTimeout(t *testing.T) {
t.Parallel()
// Neither the client nor the upstream sends frames; the idle timeout should fire.
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Use a fast-forwarding clock to accelerate the idle timeout.
now := time.Now()
callCount := 0
nowFn := func() time.Time {
callCount++
// The first few calls return the base time (initialization phase); later calls jump ahead.
if callCount <= 5 {
return now
}
return now.Add(time.Hour) // jump past the timeout
}
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
IdleTimeout: 2 * time.Second,
Now: nowFn,
})
require.NotNil(t, relayExit, "relay should exit via idle timeout")
require.Equal(t, "idle_timeout", relayExit.Stage)
require.Equal(t, "gpt-4o", result.RequestModel)
}
func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) {
t.Parallel()
clientConn := &closeSpyFrameConn{}
upstreamConn := &closeSpyFrameConn{}
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
now := time.Now()
callCount := 0
nowFn := func() time.Time {
callCount++
if callCount <= 5 {
return now
}
return now.Add(time.Hour)
}
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
IdleTimeout: 2 * time.Second,
Now: nowFn,
})
require.NotNil(t, relayExit, "relay should exit via idle timeout")
require.Equal(t, "idle_timeout", relayExit.Stage)
require.Zero(t, clientConn.CloseCalls(), "the error path must not close the client connection early; the caller decides the close code")
require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1))
}
func TestRelay_NilConnections(t *testing.T) {
t.Parallel()
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx := context.Background()
t.Run("nil client conn", func(t *testing.T) {
upstreamConn := newPassthroughTestFrameConn(nil, true)
_, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{})
require.NotNil(t, relayExit)
require.Equal(t, "relay_init", relayExit.Stage)
require.Contains(t, relayExit.Err.Error(), "nil")
})
t.Run("nil upstream conn", func(t *testing.T) {
clientConn := newPassthroughTestFrameConn(nil, true)
_, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{})
require.NotNil(t, relayExit)
require.Equal(t, "relay_init", relayExit.Stage)
require.Contains(t, relayExit.Err.Error(), "nil")
})
}
func TestRelay_MultipleUpstreamMessages(t *testing.T) {
t.Parallel()
// Upstream sends several events (delta + completed); verify multi-frame relay and usage aggregation.
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
require.Equal(t, "resp_multi", result.RequestID)
require.Equal(t, "response.completed", result.TerminalEventType)
require.Equal(t, 10, result.Usage.InputTokens)
require.Equal(t, 5, result.Usage.OutputTokens)
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
require.NotNil(t, result.FirstTokenMs)
// Verify that all 3 upstream frames were forwarded to the client.
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 3)
}
func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
turns := make([]RelayTurnResult, 0, 2)
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
OnTurnComplete: func(turn RelayTurnResult) {
turns = append(turns, turn)
},
})
require.Nil(t, relayExit)
require.Len(t, turns, 2)
require.Equal(t, "resp_turn_1", turns[0].RequestID)
require.Equal(t, "response.completed", turns[0].TerminalEventType)
require.Equal(t, 2, turns[0].Usage.InputTokens)
require.Equal(t, 1, turns[0].Usage.OutputTokens)
require.Equal(t, "resp_turn_2", turns[1].RequestID)
require.Equal(t, "response.failed", turns[1].TerminalEventType)
require.Equal(t, 3, turns[1].Usage.InputTokens)
require.Equal(t, 4, turns[1].Usage.OutputTokens)
require.Equal(t, 5, result.Usage.InputTokens)
require.Equal(t, 5, result.Usage.OutputTokens)
}
func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
base := time.Unix(0, 0)
var nowTick atomic.Int64
nowFn := func() time.Time {
step := nowTick.Add(1)
return base.Add(time.Duration(step) * 5 * time.Millisecond)
}
var turn RelayTurnResult
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
Now: nowFn,
OnTurnComplete: func(current RelayTurnResult) {
turn = current
},
})
require.Nil(t, relayExit)
require.Equal(t, "resp_metric", turn.RequestID)
require.Equal(t, "response.completed", turn.TerminalEventType)
require.NotNil(t, turn.FirstTokenMs)
require.GreaterOrEqual(t, *turn.FirstTokenMs, 0)
require.Greater(t, turn.Duration.Milliseconds(), int64(0))
require.NotNil(t, result.FirstTokenMs)
require.Greater(t, result.Duration.Milliseconds(), int64(0))
}
func TestRelay_BinaryFramePassthrough(t *testing.T) {
t.Parallel()
// Verify that binary frames are passed through without usage parsing.
binaryPayload := []byte{0x00, 0x01, 0x02, 0x03}
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageBinary,
payload: binaryPayload,
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
// Binary frames are not parsed for usage.
require.Equal(t, 0, result.Usage.InputTokens)
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
require.Equal(t, binaryPayload, clientWrites[0].payload)
}
func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageBinary,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
require.Equal(t, 0, result.Usage.InputTokens)
require.Equal(t, "", result.RequestID)
require.Equal(t, "", result.TerminalEventType)
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
}
func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: errorEvent,
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
require.Equal(t, errorEvent, clientWrites[0].payload)
}
func TestRelay_PreservesFirstMessageType(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
FirstMessageType: coderws.MessageBinary,
})
require.Nil(t, relayExit)
upstreamWrites := upstreamConn.Writes()
require.Len(t, upstreamWrites, 1)
require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType)
require.Equal(t, firstPayload, upstreamWrites[0].payload)
}
func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) {
baseline := SnapshotMetrics().UsageParseFailureTotal
// Upstream sends JSON whose usage field has the wrong shape (a string, not an object); passthrough must not be affected.
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
// Usage parsing fails, so the values stay 0, but passthrough continues.
require.Equal(t, 0, result.Usage.InputTokens)
require.Equal(t, "response.completed", result.TerminalEventType)
// The frame is still forwarded.
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, baseline+1)
}
func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) {
t.Parallel()
// Use an upstream stub whose WriteFrame always fails, so writing the first message errors out.
errConn := &errorOnWriteFrameConn{}
clientConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{})
require.NotNil(t, relayExit)
require.Equal(t, "write_upstream", relayExit.Stage)
}
func TestRelay_ContextCanceled(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
// Cancel the context immediately.
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
// The canceled context makes the first-message write fail.
require.NotNil(t, relayExit)
}
func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
stages := make([]string, 0, 8)
var stagesMu sync.Mutex
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
OnTrace: func(event RelayTraceEvent) {
stagesMu.Lock()
stages = append(stages, event.Stage)
stagesMu.Unlock()
},
})
require.Nil(t, relayExit)
stagesMu.Lock()
capturedStages := append([]string(nil), stages...)
stagesMu.Unlock()
require.Contains(t, capturedStages, "relay_start")
require.Contains(t, capturedStages, "write_first_message_ok")
require.Contains(t, capturedStages, "first_exit")
require.Contains(t, capturedStages, "relay_complete")
}
func TestRelay_TraceEvents_IdleTimeout(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
now := time.Now()
callCount := 0
nowFn := func() time.Time {
callCount++
if callCount <= 5 {
return now
}
return now.Add(time.Hour)
}
stages := make([]string, 0, 8)
var stagesMu sync.Mutex
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
IdleTimeout: 2 * time.Second,
Now: nowFn,
OnTrace: func(event RelayTraceEvent) {
stagesMu.Lock()
stages = append(stages, event.Stage)
stagesMu.Unlock()
},
})
require.NotNil(t, relayExit)
require.Equal(t, "idle_timeout", relayExit.Stage)
stagesMu.Lock()
capturedStages := append([]string(nil), stages...)
stagesMu.Unlock()
require.Contains(t, capturedStages, "idle_timeout_triggered")
require.Contains(t, capturedStages, "relay_exit")
}
// errorOnWriteFrameConn is a FrameConn implementation whose writes always fail, used to test first-message write failures.
type errorOnWriteFrameConn struct{}
func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
<-ctx.Done()
return coderws.MessageText, nil, ctx.Err()
}
func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error {
return errors.New("write failed: connection refused")
}
func (c *errorOnWriteFrameConn) Close() error {
return nil
}


@@ -0,0 +1,367 @@
package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.conn == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Read(ctx)
}
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Write(ctx, msgType, payload)
}
func (c *openAIWSClientFrameConn) Close() error {
if c == nil || c.conn == nil {
return nil
}
_ = c.conn.Close(coderws.StatusNormalClosure, "")
_ = c.conn.CloseNow()
return nil
}
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
ctx context.Context,
c *gin.Context,
clientConn *coderws.Conn,
account *Account,
token string,
firstClientMessage []byte,
hooks *OpenAIWSIngressHooks,
wsDecision OpenAIWSProtocolDecision,
) error {
if s == nil {
return errors.New("service is nil")
}
if clientConn == nil {
return errors.New("client websocket is nil")
}
if account == nil {
return errors.New("account is nil")
}
if strings.TrimSpace(token) == "" {
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
account.ID,
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
openaiwsv2RelayMessageTypeName(coderws.MessageText),
len(firstClientMessage),
)
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
}
wsHost := "-"
wsPath := "-"
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
}
logOpenAIWSV2Passthrough(
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
account.ID,
wsHost,
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
isCodexCLI := false
if c != nil {
isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
}
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
isCodexCLI = true
}
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
dialer := s.getOpenAIWSPassthroughDialer()
if dialer == nil {
return errors.New("openai ws passthrough dialer is nil")
}
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
defer cancelDial()
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
if err != nil {
logOpenAIWSV2Passthrough(
"relay_dial_failed account_id=%d status_code=%d err=%s",
account.ID,
statusCode,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
)
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
}
defer func() {
_ = upstreamConn.Close()
}()
logOpenAIWSV2Passthrough(
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
account.ID,
statusCode,
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
)
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
if !ok {
return errors.New("openai ws passthrough upstream connection does not support frame relay")
}
completedTurns := atomic.Int32{}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
WriteTimeout: s.openAIWSWriteTimeout(),
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
FirstMessageType: coderws.MessageText,
OnUsageParseFailure: func(eventType string, usageRaw string) {
logOpenAIWSV2Passthrough(
"usage_parse_failed event_type=%s usage_raw=%s",
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
)
},
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
turnNo := int(completedTurns.Add(1))
turnResult := &OpenAIForwardResult{
RequestID: turn.RequestID,
Usage: OpenAIUsage{
InputTokens: turn.Usage.InputTokens,
OutputTokens: turn.Usage.OutputTokens,
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
Stream: true,
OpenAIWSMode: true,
Duration: turn.Duration,
FirstTokenMs: turn.FirstTokenMs,
}
logOpenAIWSV2Passthrough(
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
account.ID,
turnNo,
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
turnResult.Duration.Milliseconds(),
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
turnResult.Usage.InputTokens,
turnResult.Usage.OutputTokens,
turnResult.Usage.CacheReadInputTokens,
)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnNo, turnResult, nil)
}
},
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
logOpenAIWSV2Passthrough(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
account.ID,
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
event.PayloadBytes,
event.Graceful,
event.WroteDownstream,
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
)
},
},
})
result := &OpenAIForwardResult{
RequestID: relayResult.RequestID,
Usage: OpenAIUsage{
InputTokens: relayResult.Usage.InputTokens,
OutputTokens: relayResult.Usage.OutputTokens,
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
Stream: true,
OpenAIWSMode: true,
Duration: relayResult.Duration,
FirstTokenMs: relayResult.FirstTokenMs,
}
turnCount := int(completedTurns.Load())
if relayExit == nil {
logOpenAIWSV2Passthrough(
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
// The normal path already invokes the callback once per turn on terminal events; only fire a single fallback callback in the zero-turn case.
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(1, result, nil)
}
return nil
}
logOpenAIWSV2Passthrough(
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
relayExit.WroteDownstream,
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
relayErr := relayExit.Err
if relayExit.Stage == "idle_timeout" {
relayErr = NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"client websocket idle timeout",
relayErr,
)
}
turnErr := wrapOpenAIWSIngressTurnError(
relayExit.Stage,
relayErr,
relayExit.WroteDownstream,
)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnCount+1, nil, turnErr)
}
return turnErr
}
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
err error,
statusCode int,
handshakeHeaders http.Header,
) error {
if err == nil {
return nil
}
wrappedErr := err
var dialErr *openAIWSDialError
if !errors.As(err, &dialErr) {
wrappedErr = &openAIWSDialError{
StatusCode: statusCode,
ResponseHeaders: cloneHeader(handshakeHeaders),
Err: err,
}
}
if errors.Is(err, context.Canceled) {
return err
}
if errors.Is(err, context.DeadlineExceeded) {
return NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket connect timeout",
wrappedErr,
)
}
if statusCode == http.StatusTooManyRequests {
return NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket is busy, please retry later",
wrappedErr,
)
}
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream websocket authentication failed",
wrappedErr,
)
}
if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream websocket handshake rejected",
wrappedErr,
)
}
return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
}
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
switch msgType {
case coderws.MessageText:
return "text"
case coderws.MessageBinary:
return "binary"
default:
return fmt.Sprintf("unknown(%d)", msgType)
}
}
func relayErrorText(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
if firstTokenMs == nil {
return -1
}
return *firstTokenMs
}
func logOpenAIWSV2Passthrough(format string, args ...any) {
logger.LegacyPrintf(
"service.openai_ws_v2",
"[OpenAI WS v2 passthrough] %s "+format,
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
)
}

View File

@@ -31,6 +31,10 @@ func (s *OpsService) GetDashboardOverview(ctx context.Context, filter *OpsDashbo
 filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
 overview, err := s.opsRepo.GetDashboardOverview(ctx, filter)
+if err != nil && shouldFallbackOpsPreagg(filter, err) {
+rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
+overview, err = s.opsRepo.GetDashboardOverview(ctx, rawFilter)
+}
 if err != nil {
 if errors.Is(err, ErrOpsPreaggregatedNotPopulated) {
 return nil, infraerrors.Conflict("OPS_PREAGG_NOT_READY", "Pre-aggregated ops metrics are not populated yet")

View File

@@ -22,7 +22,14 @@ func (s *OpsService) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilt
 if filter.StartTime.After(filter.EndTime) {
 return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
 }
-return s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds)
+filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
+result, err := s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds)
+if err != nil && shouldFallbackOpsPreagg(filter, err) {
+rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
+return s.opsRepo.GetErrorTrend(ctx, rawFilter, bucketSeconds)
+}
+return result, err
 }
 func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) {
@@ -41,5 +48,12 @@ func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashbo
 if filter.StartTime.After(filter.EndTime) {
 return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
 }
-return s.opsRepo.GetErrorDistribution(ctx, filter)
+filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
+result, err := s.opsRepo.GetErrorDistribution(ctx, filter)
+if err != nil && shouldFallbackOpsPreagg(filter, err) {
+rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
+return s.opsRepo.GetErrorDistribution(ctx, rawFilter)
+}
+return result, err
 }

View File

@@ -22,5 +22,12 @@ func (s *OpsService) GetLatencyHistogram(ctx context.Context, filter *OpsDashboa
 if filter.StartTime.After(filter.EndTime) {
 return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
 }
-return s.opsRepo.GetLatencyHistogram(ctx, filter)
+filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
+result, err := s.opsRepo.GetLatencyHistogram(ctx, filter)
+if err != nil && shouldFallbackOpsPreagg(filter, err) {
+rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
+return s.opsRepo.GetLatencyHistogram(ctx, rawFilter)
+}
+return result, err
 }

View File

@@ -38,3 +38,18 @@ func (m OpsQueryMode) IsValid() bool {
 return false
 }
 }
+func shouldFallbackOpsPreagg(filter *OpsDashboardFilter, err error) bool {
+return filter != nil &&
+filter.QueryMode == OpsQueryModeAuto &&
+errors.Is(err, ErrOpsPreaggregatedNotPopulated)
+}
+func cloneOpsFilterWithMode(filter *OpsDashboardFilter, mode OpsQueryMode) *OpsDashboardFilter {
+if filter == nil {
+return nil
+}
+cloned := *filter
+cloned.QueryMode = mode
+return &cloned
+}

View File

@@ -0,0 +1,66 @@
//go:build unit
package service
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestShouldFallbackOpsPreagg(t *testing.T) {
preaggErr := ErrOpsPreaggregatedNotPopulated
otherErr := errors.New("some other error")
autoFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeAuto}
rawFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeRaw}
preaggFilter := &OpsDashboardFilter{QueryMode: OpsQueryModePreagg}
tests := []struct {
name string
filter *OpsDashboardFilter
err error
want bool
}{
{"auto mode + preagg error => fallback", autoFilter, preaggErr, true},
{"auto mode + other error => no fallback", autoFilter, otherErr, false},
{"auto mode + nil error => no fallback", autoFilter, nil, false},
{"raw mode + preagg error => no fallback", rawFilter, preaggErr, false},
{"preagg mode + preagg error => no fallback", preaggFilter, preaggErr, false},
{"nil filter => no fallback", nil, preaggErr, false},
{"wrapped preagg error => fallback", autoFilter, errors.Join(preaggErr, otherErr), true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := shouldFallbackOpsPreagg(tc.filter, tc.err)
require.Equal(t, tc.want, got)
})
}
}
func TestCloneOpsFilterWithMode(t *testing.T) {
t.Run("nil filter returns nil", func(t *testing.T) {
require.Nil(t, cloneOpsFilterWithMode(nil, OpsQueryModeRaw))
})
t.Run("cloned filter has new mode", func(t *testing.T) {
groupID := int64(42)
original := &OpsDashboardFilter{
StartTime: time.Now(),
EndTime: time.Now().Add(time.Hour),
Platform: "anthropic",
GroupID: &groupID,
QueryMode: OpsQueryModeAuto,
}
cloned := cloneOpsFilterWithMode(original, OpsQueryModeRaw)
require.Equal(t, OpsQueryModeRaw, cloned.QueryMode)
require.Equal(t, OpsQueryModeAuto, original.QueryMode, "original should not be modified")
require.Equal(t, original.Platform, cloned.Platform)
require.Equal(t, original.StartTime, cloned.StartTime)
require.Equal(t, original.GroupID, cloned.GroupID)
})
}

View File

@@ -22,5 +22,13 @@ func (s *OpsService) GetThroughputTrend(ctx context.Context, filter *OpsDashboar
 if filter.StartTime.After(filter.EndTime) {
 return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
 }
-return s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds)
+filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
+result, err := s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds)
+if err != nil && shouldFallbackOpsPreagg(filter, err) {
+rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
+return s.opsRepo.GetThroughputTrend(ctx, rawFilter, bucketSeconds)
+}
+return result, err
 }

View File

@@ -676,7 +676,17 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
 }
 }
-// No reset time: fall back to the default of 5 minutes
+// Anthropic platform: a 429 without a rate-limit reset time may not be a real rate limit
+// (e.g. "Extra usage required"); do not mark the account as rate-limited, pass the error
+// straight through to the client.
+if account.Platform == PlatformAnthropic {
+slog.Warn("rate_limit_429_no_reset_time_skipped",
+"account_id", account.ID,
+"platform", account.Platform,
+"reason", "no rate limit reset time in headers, likely not a real rate limit")
+return
+}
+// Other platforms: without a reset time, default to 5 minutes.
 resetAt := time.Now().Add(5 * time.Minute)
 slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
 if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
@@ -1081,6 +1091,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
 if !account.IsTempUnschedulableEnabled() {
 return false
 }
+// A first 401 hit may go temporarily unschedulable (to give the token a refresh window);
+// if the account has previously gone temp-unschedulable for a 401, escalate this one to an
+// error (return false so the default error logic takes over).
+if statusCode == http.StatusUnauthorized {
+reason := account.TempUnschedulableReason
+// The cache may be missing the reason; fall back to reading it from the DB.
+if reason == "" {
+if dbAcc, err := s.accountRepo.GetByID(ctx, account.ID); err == nil && dbAcc != nil {
+reason = dbAcc.TempUnschedulableReason
+}
+}
+if wasTempUnschedByStatusCode(reason, statusCode) {
+slog.Info("401_escalated_to_error", "account_id", account.ID,
+"reason", "previous temp-unschedulable was also 401")
+return false
+}
+}
 rules := account.GetTempUnschedulableRules()
 if len(rules) == 0 {
 return false
@@ -1112,6 +1138,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
 return false
 }
+func wasTempUnschedByStatusCode(reason string, statusCode int) bool {
+if statusCode <= 0 {
+return false
+}
+reason = strings.TrimSpace(reason)
+if reason == "" {
+return false
+}
+var state TempUnschedState
+if err := json.Unmarshal([]byte(reason), &state); err != nil {
+return false
+}
+return state.StatusCode == statusCode
+}
 func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
 if bodyLower == "" {
 return ""
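`wasTempUnschedByStatusCode` treats the stored reason as opportunistic JSON: any empty or malformed value means "no previous record" rather than an error, which keeps legacy free-text reasons from breaking the 401 escalation path. A stand-alone sketch of that fail-open parse (the `tempUnschedState` field names mirror the JSON seen in the unit tests and are otherwise assumptions):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// tempUnschedState mirrors the JSON shape persisted in the reason column.
type tempUnschedState struct {
	StatusCode int   `json:"status_code"`
	UntilUnix  int64 `json:"until_unix"`
}

// wasTempUnschedByStatus reports whether the stored reason records the same status code.
// Empty or malformed reasons deliberately return false ("no previous record").
func wasTempUnschedByStatus(reason string, statusCode int) bool {
	if statusCode <= 0 {
		return false
	}
	reason = strings.TrimSpace(reason)
	if reason == "" {
		return false
	}
	var state tempUnschedState
	if err := json.Unmarshal([]byte(reason), &state); err != nil {
		return false
	}
	return state.StatusCode == statusCode
}

func main() {
	fmt.Println(wasTempUnschedByStatus(`{"status_code":401,"until_unix":1735689600}`, 401)) // true
	fmt.Println(wasTempUnschedByStatus("not json", 401))                                    // false
}
```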

View File

@@ -0,0 +1,119 @@
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// dbFallbackRepoStub extends errorPolicyRepoStub with a configurable DB account
// returned by GetByID, simulating cache miss + DB fallback.
type dbFallbackRepoStub struct {
errorPolicyRepoStub
dbAccount *Account // returned by GetByID when non-nil
}
func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
if r.dbAccount != nil && r.dbAccount.ID == id {
return r.dbAccount, nil
}
return nil, nil // not found, no error
}
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "", // cache miss — reason is empty
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
}
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason,
// DB also has no previous 401 record → should NOT escalate (first hit → temp unscheduled).
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 21,
TempUnschedulableReason: "", // DB also empty
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 21,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with no DB record should temp-unschedule")
}
func TestCheckErrorPolicy_401_DBFallback_DBError_FirstHit(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason,
// DB lookup returns nil (not found) → should treat as first hit → temp unscheduled.
repo := &dbFallbackRepoStub{
dbAccount: nil, // GetByID returns nil, nil
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 22,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with DB not found should temp-unschedule")
}

View File

@@ -0,0 +1,123 @@
package service
import (
"encoding/json"
"fmt"
"regexp"
"strings"
)
var registrationEmailDomainPattern = regexp.MustCompile(
`^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+$`,
)
// RegistrationEmailSuffix extracts normalized suffix in "@domain" form.
func RegistrationEmailSuffix(email string) string {
_, domain, ok := splitEmailForPolicy(email)
if !ok {
return ""
}
return "@" + domain
}
// IsRegistrationEmailSuffixAllowed checks whether an email is allowed by suffix whitelist.
// Empty whitelist means allow all.
func IsRegistrationEmailSuffixAllowed(email string, whitelist []string) bool {
if len(whitelist) == 0 {
return true
}
suffix := RegistrationEmailSuffix(email)
if suffix == "" {
return false
}
for _, allowed := range whitelist {
if suffix == allowed {
return true
}
}
return false
}
// NormalizeRegistrationEmailSuffixWhitelist normalizes and validates suffix whitelist items.
func NormalizeRegistrationEmailSuffixWhitelist(raw []string) ([]string, error) {
return normalizeRegistrationEmailSuffixWhitelist(raw, true)
}
// ParseRegistrationEmailSuffixWhitelist parses persisted JSON into normalized suffixes.
// Invalid entries are ignored to keep old misconfigurations from breaking runtime reads.
func ParseRegistrationEmailSuffixWhitelist(raw string) []string {
raw = strings.TrimSpace(raw)
if raw == "" {
return []string{}
}
var items []string
if err := json.Unmarshal([]byte(raw), &items); err != nil {
return []string{}
}
normalized, _ := normalizeRegistrationEmailSuffixWhitelist(items, false)
if len(normalized) == 0 {
return []string{}
}
return normalized
}
func normalizeRegistrationEmailSuffixWhitelist(raw []string, strict bool) ([]string, error) {
if len(raw) == 0 {
return nil, nil
}
seen := make(map[string]struct{}, len(raw))
out := make([]string, 0, len(raw))
for _, item := range raw {
normalized, err := normalizeRegistrationEmailSuffix(item)
if err != nil {
if strict {
return nil, err
}
continue
}
if normalized == "" {
continue
}
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
out = append(out, normalized)
}
if len(out) == 0 {
return nil, nil
}
return out, nil
}
func normalizeRegistrationEmailSuffix(raw string) (string, error) {
value := strings.ToLower(strings.TrimSpace(raw))
if value == "" {
return "", nil
}
domain := value
if strings.Contains(value, "@") {
if !strings.HasPrefix(value, "@") || strings.Count(value, "@") != 1 {
return "", fmt.Errorf("invalid email suffix: %q", raw)
}
domain = strings.TrimPrefix(value, "@")
}
if domain == "" || strings.Contains(domain, "@") || !registrationEmailDomainPattern.MatchString(domain) {
return "", fmt.Errorf("invalid email suffix: %q", raw)
}
return "@" + domain, nil
}
func splitEmailForPolicy(raw string) (local string, domain string, ok bool) {
email := strings.ToLower(strings.TrimSpace(raw))
local, domain, found := strings.Cut(email, "@")
if !found || local == "" || domain == "" || strings.Contains(domain, "@") {
return "", "", false
}
return local, domain, true
}
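The suffix policy above is exact-match on a normalized `@domain` string, so `@example.com` does not admit `@sub.example.com` (as the unit test below confirms), and an empty whitelist allows everything. A condensed, self-contained sketch of the check (helper names here are illustrative; the real API is `RegistrationEmailSuffix` / `IsRegistrationEmailSuffixAllowed`):

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// domainPattern accepts lowercase DNS-style domains: labels of 1-63 chars,
// no leading/trailing hyphen, at least one dot.
var domainPattern = regexp.MustCompile(
	`^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+$`,
)

// emailSuffix extracts the normalized "@domain" suffix, or "" for malformed input.
func emailSuffix(email string) string {
	email = strings.ToLower(strings.TrimSpace(email))
	local, domain, found := strings.Cut(email, "@")
	if !found || local == "" || domain == "" || strings.Contains(domain, "@") || !domainPattern.MatchString(domain) {
		return ""
	}
	return "@" + domain
}

// suffixAllowed applies the whitelist; an empty whitelist allows everything.
func suffixAllowed(email string, whitelist []string) bool {
	if len(whitelist) == 0 {
		return true
	}
	suffix := emailSuffix(email)
	if suffix == "" {
		return false
	}
	for _, allowed := range whitelist {
		if suffix == allowed {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(suffixAllowed("User@Example.com", []string{"@example.com"}))     // true
	fmt.Println(suffixAllowed("user@sub.example.com", []string{"@example.com"})) // false
}
```

One liberty taken: this sketch regex-validates the domain during extraction, whereas the original only regex-validates whitelist entries and trusts `splitEmailForPolicy` for lookups; for well-formed addresses the two agree.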

View File

@@ -0,0 +1,31 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNormalizeRegistrationEmailSuffixWhitelist(t *testing.T) {
got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar "})
require.NoError(t, err)
require.Equal(t, []string{"@example.com", "@foo.bar"}, got)
}
func TestNormalizeRegistrationEmailSuffixWhitelist_Invalid(t *testing.T) {
_, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"@invalid_domain"})
require.Error(t, err)
}
func TestParseRegistrationEmailSuffixWhitelist(t *testing.T) {
got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","@invalid_domain"]`)
require.Equal(t, []string{"@example.com", "@foo.bar"}, got)
}
func TestIsRegistrationEmailSuffixAllowed(t *testing.T) {
require.True(t, IsRegistrationEmailSuffixAllowed("user@example.com", []string{"@example.com"}))
require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.example.com", []string{"@example.com"}))
require.True(t, IsRegistrationEmailSuffixAllowed("user@any.com", []string{}))
}

View File

@@ -0,0 +1,51 @@
package service
import (
"context"
"time"
)
// ScheduledTestPlan represents a scheduled test plan domain model.
type ScheduledTestPlan struct {
ID int64 `json:"id"`
AccountID int64 `json:"account_id"`
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression"`
Enabled bool `json:"enabled"`
MaxResults int `json:"max_results"`
LastRunAt *time.Time `json:"last_run_at"`
NextRunAt *time.Time `json:"next_run_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ScheduledTestResult represents a single test execution result.
type ScheduledTestResult struct {
ID int64 `json:"id"`
PlanID int64 `json:"plan_id"`
Status string `json:"status"`
ResponseText string `json:"response_text"`
ErrorMessage string `json:"error_message"`
LatencyMs int64 `json:"latency_ms"`
StartedAt time.Time `json:"started_at"`
FinishedAt time.Time `json:"finished_at"`
CreatedAt time.Time `json:"created_at"`
}
// ScheduledTestPlanRepository defines the data access interface for test plans.
type ScheduledTestPlanRepository interface {
Create(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error)
GetByID(ctx context.Context, id int64) (*ScheduledTestPlan, error)
ListByAccountID(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error)
ListDue(ctx context.Context, now time.Time) ([]*ScheduledTestPlan, error)
Update(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error)
Delete(ctx context.Context, id int64) error
UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error
}
// ScheduledTestResultRepository defines the data access interface for test results.
type ScheduledTestResultRepository interface {
Create(ctx context.Context, result *ScheduledTestResult) (*ScheduledTestResult, error)
ListByPlanID(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error)
PruneOldResults(ctx context.Context, planID int64, keepCount int) error
}

View File

@@ -0,0 +1,139 @@
package service
import (
"context"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/robfig/cron/v3"
)
const scheduledTestDefaultMaxWorkers = 10
// ScheduledTestRunnerService periodically scans due test plans and executes them.
type ScheduledTestRunnerService struct {
planRepo ScheduledTestPlanRepository
scheduledSvc *ScheduledTestService
accountTestSvc *AccountTestService
cfg *config.Config
cron *cron.Cron
startOnce sync.Once
stopOnce sync.Once
}
// NewScheduledTestRunnerService creates a new runner.
func NewScheduledTestRunnerService(
planRepo ScheduledTestPlanRepository,
scheduledSvc *ScheduledTestService,
accountTestSvc *AccountTestService,
cfg *config.Config,
) *ScheduledTestRunnerService {
return &ScheduledTestRunnerService{
planRepo: planRepo,
scheduledSvc: scheduledSvc,
accountTestSvc: accountTestSvc,
cfg: cfg,
}
}
// Start begins the cron ticker (every minute).
func (s *ScheduledTestRunnerService) Start() {
if s == nil {
return
}
s.startOnce.Do(func() {
loc := time.Local
if s.cfg != nil {
if parsed, err := time.LoadLocation(s.cfg.Timezone); err == nil && parsed != nil {
loc = parsed
}
}
c := cron.New(cron.WithParser(scheduledTestCronParser), cron.WithLocation(loc))
_, err := c.AddFunc("* * * * *", func() { s.runScheduled() })
if err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] not started (invalid schedule): %v", err)
return
}
s.cron = c
s.cron.Start()
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] started (tick=every minute)")
})
}
// Stop gracefully shuts down the cron scheduler.
func (s *ScheduledTestRunnerService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.cron != nil {
ctx := s.cron.Stop()
select {
case <-ctx.Done():
case <-time.After(3 * time.Second):
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] cron stop timed out")
}
}
})
}
func (s *ScheduledTestRunnerService) runScheduled() {
// Delay 10s so execution lands at ~:10 of each minute instead of :00.
time.Sleep(10 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
now := time.Now()
plans, err := s.planRepo.ListDue(ctx, now)
if err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] ListDue error: %v", err)
return
}
if len(plans) == 0 {
return
}
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] found %d due plans", len(plans))
sem := make(chan struct{}, scheduledTestDefaultMaxWorkers)
var wg sync.WaitGroup
for _, plan := range plans {
sem <- struct{}{}
wg.Add(1)
go func(p *ScheduledTestPlan) {
defer wg.Done()
defer func() { <-sem }()
s.runOnePlan(ctx, p)
}(plan)
}
wg.Wait()
}
func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *ScheduledTestPlan) {
result, err := s.accountTestSvc.RunTestBackground(ctx, plan.AccountID, plan.ModelID)
if err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d RunTestBackground error: %v", plan.ID, err)
return
}
if err := s.scheduledSvc.SaveResult(ctx, plan.ID, plan.MaxResults, result); err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d SaveResult error: %v", plan.ID, err)
}
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
if err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d computeNextRun error: %v", plan.ID, err)
return
}
if err := s.planRepo.UpdateAfterRun(ctx, plan.ID, time.Now(), nextRun); err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d UpdateAfterRun error: %v", plan.ID, err)
}
}

View File

@@ -0,0 +1,94 @@
package service
import (
"context"
"fmt"
"time"
"github.com/robfig/cron/v3"
)
var scheduledTestCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
// ScheduledTestService provides CRUD operations for scheduled test plans and results.
type ScheduledTestService struct {
planRepo ScheduledTestPlanRepository
resultRepo ScheduledTestResultRepository
}
// NewScheduledTestService creates a new ScheduledTestService.
func NewScheduledTestService(
planRepo ScheduledTestPlanRepository,
resultRepo ScheduledTestResultRepository,
) *ScheduledTestService {
return &ScheduledTestService{
planRepo: planRepo,
resultRepo: resultRepo,
}
}
// CreatePlan validates the cron expression, computes next_run_at, and persists the plan.
func (s *ScheduledTestService) CreatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) {
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
if err != nil {
return nil, fmt.Errorf("invalid cron expression: %w", err)
}
plan.NextRunAt = &nextRun
if plan.MaxResults <= 0 {
plan.MaxResults = 50
}
return s.planRepo.Create(ctx, plan)
}
// GetPlan retrieves a plan by ID.
func (s *ScheduledTestService) GetPlan(ctx context.Context, id int64) (*ScheduledTestPlan, error) {
return s.planRepo.GetByID(ctx, id)
}
// ListPlansByAccount returns all plans for a given account.
func (s *ScheduledTestService) ListPlansByAccount(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error) {
return s.planRepo.ListByAccountID(ctx, accountID)
}
// UpdatePlan validates cron and updates the plan.
func (s *ScheduledTestService) UpdatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) {
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
if err != nil {
return nil, fmt.Errorf("invalid cron expression: %w", err)
}
plan.NextRunAt = &nextRun
return s.planRepo.Update(ctx, plan)
}
// DeletePlan removes a plan and its results (via CASCADE).
func (s *ScheduledTestService) DeletePlan(ctx context.Context, id int64) error {
return s.planRepo.Delete(ctx, id)
}
// ListResults returns the most recent results for a plan.
func (s *ScheduledTestService) ListResults(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error) {
if limit <= 0 {
limit = 50
}
return s.resultRepo.ListByPlanID(ctx, planID, limit)
}
// SaveResult inserts a result and prunes old entries beyond maxResults.
func (s *ScheduledTestService) SaveResult(ctx context.Context, planID int64, maxResults int, result *ScheduledTestResult) error {
result.PlanID = planID
if _, err := s.resultRepo.Create(ctx, result); err != nil {
return err
}
return s.resultRepo.PruneOldResults(ctx, planID, maxResults)
}
func computeNextRun(cronExpr string, from time.Time) (time.Time, error) {
sched, err := scheduledTestCronParser.Parse(cronExpr)
if err != nil {
return time.Time{}, err
}
return sched.Next(from), nil
}

View File

@@ -108,6 +108,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
 keys := []string{
 SettingKeyRegistrationEnabled,
 SettingKeyEmailVerifyEnabled,
+SettingKeyRegistrationEmailSuffixWhitelist,
 SettingKeyPromoCodeEnabled,
 SettingKeyPasswordResetEnabled,
 SettingKeyInvitationCodeEnabled,
@@ -144,29 +145,33 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
 // Password reset requires email verification to be enabled
 emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
 passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true"
+registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist(
+settings[SettingKeyRegistrationEmailSuffixWhitelist],
+)
 return &PublicSettings{
 RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
 EmailVerifyEnabled: emailVerifyEnabled,
+RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist,
 PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // enabled by default
 PasswordResetEnabled: passwordResetEnabled,
 InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
 TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
 TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
 TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
 SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
 SiteLogo: settings[SettingKeySiteLogo],
 SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
 APIBaseURL: settings[SettingKeyAPIBaseURL],
 ContactInfo: settings[SettingKeyContactInfo],
 DocURL: settings[SettingKeyDocURL],
 HomeContent: settings[SettingKeyHomeContent],
 HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
 PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
 PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
 SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
 CustomMenuItems: settings[SettingKeyCustomMenuItems],
 LinuxDoOAuthEnabled: linuxDoEnabled,
 }, nil
 }
@@ -196,51 +201,53 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
	// Return a struct that matches the frontend's expected format
	return &struct {
		RegistrationEnabled              bool            `json:"registration_enabled"`
		EmailVerifyEnabled               bool            `json:"email_verify_enabled"`
		RegistrationEmailSuffixWhitelist []string        `json:"registration_email_suffix_whitelist"`
		PromoCodeEnabled                 bool            `json:"promo_code_enabled"`
		PasswordResetEnabled             bool            `json:"password_reset_enabled"`
		InvitationCodeEnabled            bool            `json:"invitation_code_enabled"`
		TotpEnabled                      bool            `json:"totp_enabled"`
		TurnstileEnabled                 bool            `json:"turnstile_enabled"`
		TurnstileSiteKey                 string          `json:"turnstile_site_key,omitempty"`
		SiteName                         string          `json:"site_name"`
		SiteLogo                         string          `json:"site_logo,omitempty"`
		SiteSubtitle                     string          `json:"site_subtitle,omitempty"`
		APIBaseURL                       string          `json:"api_base_url,omitempty"`
		ContactInfo                      string          `json:"contact_info,omitempty"`
		DocURL                           string          `json:"doc_url,omitempty"`
		HomeContent                      string          `json:"home_content,omitempty"`
		HideCcsImportButton              bool            `json:"hide_ccs_import_button"`
		PurchaseSubscriptionEnabled      bool            `json:"purchase_subscription_enabled"`
		PurchaseSubscriptionURL          string          `json:"purchase_subscription_url,omitempty"`
		SoraClientEnabled                bool            `json:"sora_client_enabled"`
		CustomMenuItems                  json.RawMessage `json:"custom_menu_items"`
		LinuxDoOAuthEnabled              bool            `json:"linuxdo_oauth_enabled"`
		Version                          string          `json:"version,omitempty"`
	}{
		RegistrationEnabled:              settings.RegistrationEnabled,
		EmailVerifyEnabled:               settings.EmailVerifyEnabled,
		RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
		PromoCodeEnabled:                 settings.PromoCodeEnabled,
		PasswordResetEnabled:             settings.PasswordResetEnabled,
		InvitationCodeEnabled:            settings.InvitationCodeEnabled,
		TotpEnabled:                      settings.TotpEnabled,
		TurnstileEnabled:                 settings.TurnstileEnabled,
		TurnstileSiteKey:                 settings.TurnstileSiteKey,
		SiteName:                         settings.SiteName,
		SiteLogo:                         settings.SiteLogo,
		SiteSubtitle:                     settings.SiteSubtitle,
		APIBaseURL:                       settings.APIBaseURL,
		ContactInfo:                      settings.ContactInfo,
		DocURL:                           settings.DocURL,
		HomeContent:                      settings.HomeContent,
		HideCcsImportButton:              settings.HideCcsImportButton,
		PurchaseSubscriptionEnabled:      settings.PurchaseSubscriptionEnabled,
		PurchaseSubscriptionURL:          settings.PurchaseSubscriptionURL,
		SoraClientEnabled:                settings.SoraClientEnabled,
		CustomMenuItems:                  filterUserVisibleMenuItems(settings.CustomMenuItems),
		LinuxDoOAuthEnabled:              settings.LinuxDoOAuthEnabled,
		Version:                          s.version,
	}, nil
}
@@ -356,12 +363,25 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
	if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
		return err
	}
	normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist)
	if err != nil {
		return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
	}
	if normalizedWhitelist == nil {
		normalizedWhitelist = []string{}
	}
	settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist

	updates := make(map[string]string)

	// Registration settings
	updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
	updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
	registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist)
	if err != nil {
		return fmt.Errorf("marshal registration email suffix whitelist: %w", err)
	}
	updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
	updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
	updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
	updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
@@ -514,6 +534,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
	return value == "true"
}

// GetRegistrationEmailSuffixWhitelist returns the normalized registration email suffix whitelist.
func (s *SettingService) GetRegistrationEmailSuffixWhitelist(ctx context.Context) []string {
	value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEmailSuffixWhitelist)
	if err != nil {
		return []string{}
	}
	return ParseRegistrationEmailSuffixWhitelist(value)
}
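The whitelist semantics exercised by the tests in this diff (lowercase, trim, drop invalid or empty entries, require an `@domain` shape) can be sketched in isolation. This is a hypothetical stand-in for `ParseRegistrationEmailSuffixWhitelist`, not the actual implementation; the helper name, regex, and validation rules here are assumptions:

```go
package main

import (
	"encoding/json"
	"fmt"
	"regexp"
	"strings"
)

// domainRe is an illustrative domain check: labels of letters/digits/hyphens
// joined by dots. The real validator may be stricter or looser.
var domainRe = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]*[a-z0-9])?(\.[a-z0-9]([a-z0-9-]*[a-z0-9])?)+$`)

// parseSuffixWhitelist decodes a JSON array of email suffixes, trims and
// lowercases each entry, and keeps only well-formed "@domain" suffixes.
func parseSuffixWhitelist(raw string) []string {
	var entries []string
	if err := json.Unmarshal([]byte(raw), &entries); err != nil {
		return []string{}
	}
	out := make([]string, 0, len(entries))
	for _, e := range entries {
		s := strings.ToLower(strings.TrimSpace(e))
		if !strings.HasPrefix(s, "@") {
			continue // entries must be "@domain" suffixes
		}
		if !domainRe.MatchString(strings.TrimPrefix(s, "@")) {
			continue // e.g. "@invalid_domain" is dropped
		}
		out = append(out, s)
	}
	return out
}

func main() {
	fmt.Println(parseSuffixWhitelist(`["@EXAMPLE.com"," @foo.bar ","@invalid_domain",""]`))
}
```

Feeding it the same input the unit test uses yields only the two valid suffixes, matching the expectation `["@example.com", "@foo.bar"]`.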
// IsPromoCodeEnabled reports whether the promo code feature is enabled.
func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
	value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled)
@@ -617,20 +646,21 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
	// Initialize default settings
	defaults := map[string]string{
		SettingKeyRegistrationEnabled:              "true",
		SettingKeyEmailVerifyEnabled:               "false",
		SettingKeyRegistrationEmailSuffixWhitelist: "[]",
		SettingKeyPromoCodeEnabled:                 "true", // promo codes enabled by default
		SettingKeySiteName:                         "Sub2API",
		SettingKeySiteLogo:                         "",
		SettingKeyPurchaseSubscriptionEnabled:      "false",
		SettingKeyPurchaseSubscriptionURL:          "",
		SettingKeySoraClientEnabled:                "false",
		SettingKeyCustomMenuItems:                  "[]",
		SettingKeyDefaultConcurrency:               strconv.Itoa(s.cfg.Default.UserConcurrency),
		SettingKeyDefaultBalance:                   strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
		SettingKeyDefaultSubscriptions:             "[]",
		SettingKeySMTPPort:                         "587",
		SettingKeySMTPUseTLS:                       "false",

		// Model fallback defaults
		SettingKeyEnableModelFallback:    "false",
		SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
@@ -661,33 +691,34 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
	emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
	result := &SystemSettings{
		RegistrationEnabled:              settings[SettingKeyRegistrationEnabled] == "true",
		EmailVerifyEnabled:               emailVerifyEnabled,
		RegistrationEmailSuffixWhitelist: ParseRegistrationEmailSuffixWhitelist(settings[SettingKeyRegistrationEmailSuffixWhitelist]),
		PromoCodeEnabled:                 settings[SettingKeyPromoCodeEnabled] != "false", // enabled by default
		PasswordResetEnabled:             emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
		InvitationCodeEnabled:            settings[SettingKeyInvitationCodeEnabled] == "true",
		TotpEnabled:                      settings[SettingKeyTotpEnabled] == "true",
		SMTPHost:                         settings[SettingKeySMTPHost],
		SMTPUsername:                     settings[SettingKeySMTPUsername],
		SMTPFrom:                         settings[SettingKeySMTPFrom],
		SMTPFromName:                     settings[SettingKeySMTPFromName],
		SMTPUseTLS:                       settings[SettingKeySMTPUseTLS] == "true",
		SMTPPasswordConfigured:           settings[SettingKeySMTPPassword] != "",
		TurnstileEnabled:                 settings[SettingKeyTurnstileEnabled] == "true",
		TurnstileSiteKey:                 settings[SettingKeyTurnstileSiteKey],
		TurnstileSecretKeyConfigured:     settings[SettingKeyTurnstileSecretKey] != "",
		SiteName:                         s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
		SiteLogo:                         settings[SettingKeySiteLogo],
		SiteSubtitle:                     s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
		APIBaseURL:                       settings[SettingKeyAPIBaseURL],
		ContactInfo:                      settings[SettingKeyContactInfo],
		DocURL:                           settings[SettingKeyDocURL],
		HomeContent:                      settings[SettingKeyHomeContent],
		HideCcsImportButton:              settings[SettingKeyHideCcsImportButton] == "true",
		PurchaseSubscriptionEnabled:      settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
		PurchaseSubscriptionURL:          strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
		SoraClientEnabled:                settings[SettingKeySoraClientEnabled] == "true",
		CustomMenuItems:                  settings[SettingKeyCustomMenuItems],
	}

	// Parse integer values


@@ -0,0 +1,64 @@
//go:build unit

package service

import (
	"context"
	"testing"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/stretchr/testify/require"
)

type settingPublicRepoStub struct {
	values map[string]string
}

func (s *settingPublicRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
	panic("unexpected Get call")
}

func (s *settingPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
	panic("unexpected GetValue call")
}

func (s *settingPublicRepoStub) Set(ctx context.Context, key, value string) error {
	panic("unexpected Set call")
}

func (s *settingPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
	out := make(map[string]string, len(keys))
	for _, key := range keys {
		if value, ok := s.values[key]; ok {
			out[key] = value
		}
	}
	return out, nil
}

func (s *settingPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
	panic("unexpected SetMultiple call")
}

func (s *settingPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
	panic("unexpected GetAll call")
}

func (s *settingPublicRepoStub) Delete(ctx context.Context, key string) error {
	panic("unexpected Delete call")
}

func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelist(t *testing.T) {
	repo := &settingPublicRepoStub{
		values: map[string]string{
			SettingKeyRegistrationEnabled:              "true",
			SettingKeyEmailVerifyEnabled:               "true",
			SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","@invalid_domain",""]`,
		},
	}
	svc := NewSettingService(repo, &config.Config{})

	settings, err := svc.GetPublicSettings(context.Background())
	require.NoError(t, err)
	require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist)
}


@@ -172,6 +172,28 @@ func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGrou
	require.Nil(t, repo.updates)
}

func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Normalized(t *testing.T) {
	repo := &settingUpdateRepoStub{}
	svc := NewSettingService(repo, &config.Config{})

	err := svc.UpdateSettings(context.Background(), &SystemSettings{
		RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar "},
	})

	require.NoError(t, err)
	require.Equal(t, `["@example.com","@foo.bar"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist])
}

func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Invalid(t *testing.T) {
	repo := &settingUpdateRepoStub{}
	svc := NewSettingService(repo, &config.Config{})

	err := svc.UpdateSettings(context.Background(), &SystemSettings{
		RegistrationEmailSuffixWhitelist: []string{"@invalid_domain"},
	})

	require.Error(t, err)
	require.Equal(t, "INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", infraerrors.Reason(err))
}

func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
	got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`)
	require.Equal(t, []DefaultSubscriptionSetting{


@@ -1,12 +1,13 @@
package service

type SystemSettings struct {
	RegistrationEnabled              bool
	EmailVerifyEnabled               bool
	RegistrationEmailSuffixWhitelist []string
	PromoCodeEnabled                 bool
	PasswordResetEnabled             bool
	InvitationCodeEnabled            bool
	TotpEnabled                      bool // TOTP two-factor authentication

	SMTPHost string
	SMTPPort int
@@ -76,22 +77,23 @@ type DefaultSubscriptionSetting struct {
}

type PublicSettings struct {
	RegistrationEnabled              bool
	EmailVerifyEnabled               bool
	RegistrationEmailSuffixWhitelist []string
	PromoCodeEnabled                 bool
	PasswordResetEnabled             bool
	InvitationCodeEnabled            bool
	TotpEnabled                      bool // TOTP two-factor authentication
	TurnstileEnabled                 bool
	TurnstileSiteKey                 string
	SiteName                         string
	SiteLogo                         string
	SiteSubtitle                     string
	APIBaseURL                       string
	ContactInfo                      string
	DocURL                           string
	HomeContent                      string
	HideCcsImportButton              bool

	PurchaseSubscriptionEnabled bool
	PurchaseSubscriptionURL     string


@@ -22,6 +22,10 @@ type UserListFilters struct {
	Role       string           // User role filter
	Search     string           // Search in email, username
	Attributes map[int64]string // Custom attribute filters: attributeID -> value

	// IncludeSubscriptions controls whether ListWithFilters should load active subscriptions.
	// For large datasets this can be expensive; admin list pages should enable it on demand.
	// nil means not specified (default: load subscriptions for backward compatibility).
	IncludeSubscriptions *bool
}
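The `*bool` field is a deliberate tri-state: `nil` means "caller did not specify", while `&true`/`&false` are explicit choices. A minimal caller-side sketch of resolving it to an effective value (the helper name is illustrative, not part of the diff):

```go
package main

import "fmt"

// includeSubscriptions resolves a tri-state *bool filter field:
// nil means "not specified" and falls back to the backward-compatible
// default of loading subscriptions.
func includeSubscriptions(flag *bool) bool {
	if flag == nil {
		return true // default: load subscriptions
	}
	return *flag
}

func main() {
	off := false
	fmt.Println(includeSubscriptions(nil), includeSubscriptions(&off)) // true false
}
```

This is the standard Go idiom for optional booleans in filter structs, since a plain `bool`'s zero value `false` cannot be distinguished from an explicit `false`.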

type UserRepository interface {


@@ -274,6 +274,26 @@ func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Co
	return svc
}

// ProvideScheduledTestService creates ScheduledTestService.
func ProvideScheduledTestService(
	planRepo ScheduledTestPlanRepository,
	resultRepo ScheduledTestResultRepository,
) *ScheduledTestService {
	return NewScheduledTestService(planRepo, resultRepo)
}

// ProvideScheduledTestRunnerService creates and starts ScheduledTestRunnerService.
func ProvideScheduledTestRunnerService(
	planRepo ScheduledTestPlanRepository,
	scheduledSvc *ScheduledTestService,
	accountTestSvc *AccountTestService,
	cfg *config.Config,
) *ScheduledTestRunnerService {
	svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, cfg)
	svc.Start()
	return svc
}

// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
func ProvideOpsScheduledReportService(
	opsService *OpsService,
@@ -380,4 +400,6 @@ var ProviderSet = wire.NewSet(
	ProvideIdempotencyCoordinator,
	ProvideSystemOperationLockService,
	ProvideIdempotencyCleanupService,
	ProvideScheduledTestService,
	ProvideScheduledTestRunnerService,
)


@@ -0,0 +1,33 @@
-- Improve admin fuzzy-search performance on large datasets.
-- Best effort:
--   1) try enabling pg_trgm
--   2) only create trigram indexes when the extension is available
DO $$
BEGIN
    BEGIN
        CREATE EXTENSION IF NOT EXISTS pg_trgm;
    EXCEPTION
        WHEN OTHERS THEN
            RAISE NOTICE 'pg_trgm extension not created: %', SQLERRM;
    END;

    IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm') THEN
        EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_email_trgm
                 ON users USING gin (email gin_trgm_ops)';
        EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_username_trgm
                 ON users USING gin (username gin_trgm_ops)';
        EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_notes_trgm
                 ON users USING gin (notes gin_trgm_ops)';
        EXECUTE 'CREATE INDEX IF NOT EXISTS idx_accounts_name_trgm
                 ON accounts USING gin (name gin_trgm_ops)';
        EXECUTE 'CREATE INDEX IF NOT EXISTS idx_api_keys_key_trgm
                 ON api_keys USING gin ("key" gin_trgm_ops)';
        EXECUTE 'CREATE INDEX IF NOT EXISTS idx_api_keys_name_trgm
                 ON api_keys USING gin (name gin_trgm_ops)';
    ELSE
        RAISE NOTICE 'skip trigram indexes because pg_trgm is unavailable';
    END IF;
END
$$;


@@ -0,0 +1,30 @@
-- 066_add_scheduled_test_tables.sql
-- Scheduled account test plans and results
CREATE TABLE IF NOT EXISTS scheduled_test_plans (
id BIGSERIAL PRIMARY KEY,
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
model_id VARCHAR(100) NOT NULL DEFAULT '',
cron_expression VARCHAR(100) NOT NULL DEFAULT '*/30 * * * *',
enabled BOOLEAN NOT NULL DEFAULT true,
max_results INT NOT NULL DEFAULT 50,
last_run_at TIMESTAMPTZ,
next_run_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_stp_account_id ON scheduled_test_plans(account_id);
CREATE INDEX IF NOT EXISTS idx_stp_enabled_next_run ON scheduled_test_plans(enabled, next_run_at) WHERE enabled = true;
CREATE TABLE IF NOT EXISTS scheduled_test_results (
id BIGSERIAL PRIMARY KEY,
plan_id BIGINT NOT NULL REFERENCES scheduled_test_plans(id) ON DELETE CASCADE,
status VARCHAR(20) NOT NULL DEFAULT 'success',
response_text TEXT NOT NULL DEFAULT '',
error_message TEXT NOT NULL DEFAULT '',
latency_ms BIGINT NOT NULL DEFAULT 0,
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
finished_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_str_plan_created ON scheduled_test_results(plan_id, created_at DESC);


@@ -1,12 +0,0 @@
#!/usr/bin/env bash
# Quick script for building the image locally, so build arguments need not be retyped on the command line.
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
docker build -t sub2api:latest \
--build-arg GOPROXY=https://goproxy.cn,direct \
--build-arg GOSUMDB=sum.golang.google.cn \
-f "${SCRIPT_DIR}/Dockerfile" \
"${SCRIPT_DIR}"


@@ -112,7 +112,7 @@ POSTGRES_DB=sub2api
DATABASE_PORT=5432

# -----------------------------------------------------------------------------
# PostgreSQL server-side parameters (optional)
# -----------------------------------------------------------------------------
# POSTGRES_MAX_CONNECTIONS: the maximum number of connections the PostgreSQL server allows.
# Must be >= (sum of DATABASE_MAX_OPEN_CONNS across all Sub2API instances) plus headroom (e.g. 20%).
@@ -163,7 +163,7 @@ REDIS_PORT=6379
# Leave empty for no password (default for local development)
REDIS_PASSWORD=
REDIS_DB=0
# Maximum number of client connections the Redis server allows (optional)
REDIS_MAXCLIENTS=50000
# Redis connection pool size (default 1024)
REDIS_POOL_SIZE=4096


@@ -209,8 +209,9 @@ gateway:
  openai_ws:
    # New WS-mode router (disabled by default). When off, the current legacy behavior is kept.
    mode_router_v2_enabled: false
    # Default ingress mode: off|ctx_pool|passthrough (only takes effect when mode_router_v2_enabled=true)
    # Legacy values are accepted for compatibility: shared/dedicated are treated as ctx_pool.
    ingress_mode_default: ctx_pool
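The legacy-value compatibility described in the config comment can be sketched as a small normalization step at config-load time. This is illustrative only: the function name and the fallback for unknown values are assumptions, not code from this diff:

```go
package main

import "fmt"

// normalizeIngressMode maps config values onto the new mode set.
// The legacy values "shared" and "dedicated" are treated as "ctx_pool";
// unrecognized values fall back to "off" (an assumed safe default).
func normalizeIngressMode(v string) string {
	switch v {
	case "off", "ctx_pool", "passthrough":
		return v
	case "shared", "dedicated": // legacy values
		return "ctx_pool"
	default:
		return "off"
	}
}

func main() {
	fmt.Println(normalizeIngressMode("shared"), normalizeIngressMode("passthrough"))
}
```

Normalizing once at load time keeps the rest of the router free of legacy-value checks.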
    # Global master switch (default true); when off, all requests keep the original HTTP/SSE routing
    enabled: true
    # Per-account-type switches


@@ -1,212 +0,0 @@
# =============================================================================
# Sub2API Docker Compose Test Configuration (Local Build)
# =============================================================================
# Quick Start:
# 1. Copy .env.example to .env and configure
# 2. docker-compose -f docker-compose-test.yml up -d --build
# 3. Check logs: docker-compose -f docker-compose-test.yml logs -f sub2api
# 4. Access: http://localhost:8080
#
# This configuration builds the image from source (Dockerfile in project root).
# All configuration is done via environment variables.
# No Setup Wizard needed - the system auto-initializes on first run.
# =============================================================================
services:
# ===========================================================================
# Sub2API Application
# ===========================================================================
sub2api:
image: sub2api:latest
build:
context: ..
dockerfile: Dockerfile
container_name: sub2api
restart: unless-stopped
ulimits:
nofile:
soft: 100000
hard: 100000
ports:
- "${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080"
volumes:
# Data persistence (config.yaml will be auto-generated here)
- sub2api_data:/app/data
# Mount custom config.yaml (optional, overrides auto-generated config)
# - ./config.yaml:/app/data/config.yaml:ro
environment:
# =======================================================================
# Auto Setup (REQUIRED for Docker deployment)
# =======================================================================
- AUTO_SETUP=true
# =======================================================================
# Server Configuration
# =======================================================================
- SERVER_HOST=0.0.0.0
- SERVER_PORT=8080
- SERVER_MODE=${SERVER_MODE:-release}
- RUN_MODE=${RUN_MODE:-standard}
# =======================================================================
# Database Configuration (PostgreSQL)
# =======================================================================
- DATABASE_HOST=postgres
- DATABASE_PORT=5432
- DATABASE_USER=${POSTGRES_USER:-sub2api}
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
- DATABASE_SSLMODE=disable
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
# =======================================================================
# Redis Configuration
# =======================================================================
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
- REDIS_DB=${REDIS_DB:-0}
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
# =======================================================================
# Admin Account (auto-created on first run)
# =======================================================================
- ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
# =======================================================================
# JWT Configuration
# =======================================================================
# Leave empty to auto-generate (recommended)
- JWT_SECRET=${JWT_SECRET:-}
- JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24}
# =======================================================================
# Timezone Configuration
# This affects ALL time operations in the application:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# - Log timestamps
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
# =======================================================================
- TZ=${TZ:-Asia/Shanghai}
# =======================================================================
# Gemini OAuth Configuration (for Gemini accounts)
# =======================================================================
- GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-}
- GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-}
- GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-}
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
# Built-in OAuth client secrets (optional)
# SECURITY: This repo does not embed third-party client_secret.
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
# =======================================================================
# Security Configuration (URL Allowlist)
# =======================================================================
# Allow private IP addresses for CRS sync (for internal deployments)
- SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-true}
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
networks:
- sub2api-network
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 30s
# ===========================================================================
# PostgreSQL Database
# ===========================================================================
postgres:
image: postgres:18-alpine
container_name: sub2api-postgres
restart: unless-stopped
ulimits:
nofile:
soft: 100000
hard: 100000
volumes:
- postgres_data:/var/lib/postgresql/data
environment:
# postgres:18-alpine defaults to PGDATA=/var/lib/postgresql/18/docker (inside the anonymous volume /var/lib/postgresql declared by the image).
# Without an explicit PGDATA, data would not be persisted to the named postgres_data volume even though it is mounted at /var/lib/postgresql/data;
# docker compose down/up would then trigger a fresh initdb, losing users, passwords, and other data.
- PGDATA=/var/lib/postgresql/data
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
- TZ=${TZ:-Asia/Shanghai}
networks:
- sub2api-network
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
# Note: no ports are exposed to the host; the application connects over the internal network.
# For debugging, you can temporarily add: ports: ["127.0.0.1:5433:5432"]
# ===========================================================================
# Redis Cache
# ===========================================================================
redis:
image: redis:8-alpine
container_name: sub2api-redis
restart: unless-stopped
ulimits:
nofile:
soft: 100000
hard: 100000
volumes:
- redis_data:/data
command: >
redis-server
--save 60 1
--appendonly yes
--appendfsync everysec
${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}}
environment:
- TZ=${TZ:-Asia/Shanghai}
# REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag)
- REDISCLI_AUTH=${REDIS_PASSWORD:-}
networks:
- sub2api-network
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
start_period: 5s
# =============================================================================
# Volumes
# =============================================================================
volumes:
sub2api_data:
driver: local
postgres_data:
driver: local
redis_data:
driver: local
# =============================================================================
# Networks
# =============================================================================
networks:
sub2api-network:
driver: bridge
