mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-08 01:00:21 +08:00
Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bab4bb9904 | ||
|
|
33bae6f49b | ||
|
|
32d619a56b | ||
|
|
642432cf2a | ||
|
|
61e9598b08 | ||
|
|
d4e34c7514 | ||
|
|
bfe7a5e452 | ||
|
|
77d916ffec | ||
|
|
831abf7977 | ||
|
|
817a491087 | ||
|
|
9a8dacc514 | ||
|
|
8adf80d98b | ||
|
|
62686a6213 | ||
|
|
3a089242f8 | ||
|
|
9d70c38504 | ||
|
|
aeb464f3ca | ||
|
|
7076717b20 | ||
|
|
c0a4fcea0a | ||
|
|
aa2b195c86 | ||
|
|
1d0872e7ca | ||
|
|
33988637b5 | ||
|
|
d4f6ad7225 | ||
|
|
078fefed03 | ||
|
|
5b10af85b4 | ||
|
|
4caf95e5dd | ||
|
|
8e1bcf53bb | ||
|
|
064f9be7e4 | ||
|
|
adcfb44cb7 | ||
|
|
3d79773ba2 | ||
|
|
6aa8cbbf20 | ||
|
|
742e73c9c2 | ||
|
|
f8de2bdedc | ||
|
|
59879b7fa7 | ||
|
|
27abae21b8 | ||
|
|
0819c8a51a | ||
|
|
9dcd3cd491 | ||
|
|
49767cccd2 | ||
|
|
29fb447daa | ||
|
|
f6fe5b552d | ||
|
|
bd0801a887 | ||
|
|
05b1c66aa8 | ||
|
|
80ae592c23 | ||
|
|
ba6de4c4d4 | ||
|
|
46ea9170cb | ||
|
|
7d318aeefa | ||
|
|
0aa3cf677a | ||
|
|
72961c5858 | ||
|
|
a05711a37a | ||
|
|
efc9e1d673 |
105
AGENTS.md
105
AGENTS.md
@@ -1,105 +0,0 @@
|
|||||||
# Repository Guidelines
|
|
||||||
|
|
||||||
## Project Structure & Module Organization
|
|
||||||
- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output.
|
|
||||||
- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`.
|
|
||||||
- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`).
|
|
||||||
- `openspec/`: Spec-driven change docs (`changes/<id>/{proposal,design,tasks}.md`).
|
|
||||||
- `tools/`: Utility scripts (security/perf checks).
|
|
||||||
|
|
||||||
## Build, Test, and Development Commands
|
|
||||||
```bash
|
|
||||||
make build # Build backend + frontend
|
|
||||||
make test # Backend tests + frontend lint/typecheck
|
|
||||||
cd backend && make build # Build backend binary
|
|
||||||
cd backend && make test-unit # Go unit tests
|
|
||||||
cd backend && make test-integration # Go integration tests
|
|
||||||
cd backend && make test # go test ./... + golangci-lint
|
|
||||||
cd frontend && pnpm install --frozen-lockfile
|
|
||||||
cd frontend && pnpm dev # Vite dev server
|
|
||||||
cd frontend && pnpm build # Type-check + production build
|
|
||||||
cd frontend && pnpm test:run # Vitest run
|
|
||||||
cd frontend && pnpm test:coverage # Vitest + coverage report
|
|
||||||
python3 tools/secret_scan.py # Secret scan
|
|
||||||
```
|
|
||||||
|
|
||||||
## Coding Style & Naming Conventions
|
|
||||||
- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`).
|
|
||||||
- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard).
|
|
||||||
- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`.
|
|
||||||
- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`.
|
|
||||||
|
|
||||||
## Go & Frontend Development Standards
|
|
||||||
- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears.
|
|
||||||
- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency.
|
|
||||||
- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable.
|
|
||||||
|
|
||||||
### Go Performance Rules
|
|
||||||
- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code.
|
|
||||||
- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable.
|
|
||||||
- Every external I/O path must use `context.Context` with explicit timeout/cancel.
|
|
||||||
- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources.
|
|
||||||
- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`).
|
|
||||||
- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`.
|
|
||||||
- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths.
|
|
||||||
- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after).
|
|
||||||
- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out.
|
|
||||||
- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention.
|
|
||||||
- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible.
|
|
||||||
- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary.
|
|
||||||
- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions.
|
|
||||||
- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain.
|
|
||||||
- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible.
|
|
||||||
- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled.
|
|
||||||
|
|
||||||
### Data Access & Cache Rules
|
|
||||||
- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR.
|
|
||||||
- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints.
|
|
||||||
- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths.
|
|
||||||
- Keep transactions short; never perform external RPC/network calls inside DB transactions.
|
|
||||||
- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`).
|
|
||||||
- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks.
|
|
||||||
- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd.
|
|
||||||
- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths.
|
|
||||||
|
|
||||||
### Frontend Performance Rules
|
|
||||||
- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components.
|
|
||||||
- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs.
|
|
||||||
- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it.
|
|
||||||
- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed.
|
|
||||||
- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows).
|
|
||||||
- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking.
|
|
||||||
- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review.
|
|
||||||
- Avoid expensive inline computations in templates; move to cached `computed` selectors.
|
|
||||||
- Keep state normalized; avoid duplicated derived state across multiple stores/components.
|
|
||||||
- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import.
|
|
||||||
- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`.
|
|
||||||
- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks.
|
|
||||||
- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes.
|
|
||||||
|
|
||||||
### Performance Budget & PR Evidence
|
|
||||||
- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline.
|
|
||||||
- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval.
|
|
||||||
- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table).
|
|
||||||
- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output).
|
|
||||||
|
|
||||||
### Quality Gate
|
|
||||||
- Any changed code must include new or updated unit tests.
|
|
||||||
- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules).
|
|
||||||
- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description.
|
|
||||||
|
|
||||||
## Testing Guidelines
|
|
||||||
- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant.
|
|
||||||
- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`.
|
|
||||||
- Enforce unit-test and coverage rules defined in `Quality Gate`.
|
|
||||||
- Before opening a PR, run `make test` plus targeted tests for touched areas.
|
|
||||||
|
|
||||||
## Commit & Pull Request Guidelines
|
|
||||||
- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`.
|
|
||||||
- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes.
|
|
||||||
- For behavior/API changes, add or update `openspec/changes/...` artifacts.
|
|
||||||
- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR.
|
|
||||||
|
|
||||||
## Security & Configuration Tips
|
|
||||||
- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials.
|
|
||||||
- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev.
|
|
||||||
@@ -137,8 +137,6 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
|
|||||||
|
|
||||||
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
|
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
|
||||||
|
|
||||||
如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。
|
|
||||||
|
|
||||||
#### 前置条件
|
#### 前置条件
|
||||||
|
|
||||||
- Docker 20.10+
|
- Docker 20.10+
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ func provideCleanup(
|
|||||||
geminiOAuth *service.GeminiOAuthService,
|
geminiOAuth *service.GeminiOAuthService,
|
||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
@@ -216,6 +217,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"ScheduledTestRunnerService", func() error {
|
||||||
|
if scheduledTestRunner != nil {
|
||||||
|
scheduledTestRunner.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
infraSteps := []cleanupStep{
|
infraSteps := []cleanupStep{
|
||||||
|
|||||||
@@ -195,7 +195,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
||||||
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
||||||
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
|
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
|
scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
|
||||||
|
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||||
|
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||||
|
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||||
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||||
@@ -225,7 +229,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
|
||||||
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -273,6 +278,7 @@ func provideCleanup(
|
|||||||
geminiOAuth *service.GeminiOAuthService,
|
geminiOAuth *service.GeminiOAuthService,
|
||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
@@ -402,6 +408,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"ScheduledTestRunnerService", func() error {
|
||||||
|
if scheduledTestRunner != nil {
|
||||||
|
scheduledTestRunner.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
infraSteps := []cleanupStep{
|
infraSteps := []cleanupStep{
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
|||||||
geminiOAuthSvc,
|
geminiOAuthSvc,
|
||||||
antigravityOAuthSvc,
|
antigravityOAuthSvc,
|
||||||
nil, // openAIGateway
|
nil, // openAIGateway
|
||||||
|
nil, // scheduledTestRunner
|
||||||
)
|
)
|
||||||
|
|
||||||
require.NotPanics(t, func() {
|
require.NotPanics(t, func() {
|
||||||
|
|||||||
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
|||||||
type GatewayOpenAIWSConfig struct {
|
type GatewayOpenAIWSConfig struct {
|
||||||
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
||||||
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
||||||
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
|
// IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough)
|
||||||
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
||||||
// Enabled: 全局总开关(默认 true)
|
// Enabled: 全局总开关(默认 true)
|
||||||
Enabled bool `mapstructure:"enabled"`
|
Enabled bool `mapstructure:"enabled"`
|
||||||
@@ -1227,7 +1227,7 @@ func setDefaults() {
|
|||||||
|
|
||||||
// Ops (vNext)
|
// Ops (vNext)
|
||||||
viper.SetDefault("ops.enabled", true)
|
viper.SetDefault("ops.enabled", true)
|
||||||
viper.SetDefault("ops.use_preaggregated_tables", false)
|
viper.SetDefault("ops.use_preaggregated_tables", true)
|
||||||
viper.SetDefault("ops.cleanup.enabled", true)
|
viper.SetDefault("ops.cleanup.enabled", true)
|
||||||
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
|
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
|
||||||
// Retention days: vNext defaults to 30 days across ops datasets.
|
// Retention days: vNext defaults to 30 days across ops datasets.
|
||||||
@@ -1335,7 +1335,7 @@ func setDefaults() {
|
|||||||
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||||
viper.SetDefault("gateway.openai_ws.enabled", true)
|
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||||
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
||||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
|
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
|
||||||
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
||||||
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
||||||
viper.SetDefault("gateway.openai_ws.force_http", false)
|
viper.SetDefault("gateway.openai_ws.force_http", false)
|
||||||
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
|
|||||||
}
|
}
|
||||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
||||||
switch mode {
|
switch mode {
|
||||||
case "off", "shared", "dedicated":
|
case "off", "ctx_pool", "passthrough":
|
||||||
|
case "shared", "dedicated":
|
||||||
|
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
|
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
||||||
|
|||||||
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
|||||||
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
||||||
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
||||||
}
|
}
|
||||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
|
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
|
||||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
|
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
|
|||||||
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ingress_mode_default 必须为 off|shared|dedicated",
|
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
|
||||||
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
||||||
wantErr: "gateway.openai_ws.ingress_mode_default",
|
wantErr: "gateway.openai_ws.ingress_mode_default",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -217,6 +217,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
if len(search) > 100 {
|
if len(search) > 100 {
|
||||||
search = search[:100]
|
search = search[:100]
|
||||||
}
|
}
|
||||||
|
lite := parseBoolQueryWithDefault(c.Query("lite"), false)
|
||||||
|
|
||||||
var groupID int64
|
var groupID int64
|
||||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||||
@@ -235,10 +236,16 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
accountIDs[i] = acc.ID
|
accountIDs[i] = acc.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
|
concurrencyCounts := make(map[int64]int)
|
||||||
if err != nil {
|
var windowCosts map[int64]float64
|
||||||
// Log error but don't fail the request, just use 0 for all
|
var activeSessions map[int64]int
|
||||||
concurrencyCounts = make(map[int64]int)
|
var rpmCounts map[int64]int
|
||||||
|
|
||||||
|
// 始终获取并发数(Redis ZCARD,极低开销)
|
||||||
|
if h.concurrencyService != nil {
|
||||||
|
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||||
|
concurrencyCounts = cc
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||||
@@ -262,12 +269,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 并行获取窗口费用、活跃会话数和 RPM 计数
|
// 始终获取 RPM 计数(Redis GET,极低开销)
|
||||||
var windowCosts map[int64]float64
|
|
||||||
var activeSessions map[int64]int
|
|
||||||
var rpmCounts map[int64]int
|
|
||||||
|
|
||||||
// 获取 RPM 计数(批量查询)
|
|
||||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||||
if rpmCounts == nil {
|
if rpmCounts == nil {
|
||||||
@@ -275,7 +277,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
// 始终获取活跃会话数(Redis ZCARD,低开销)
|
||||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||||
if activeSessions == nil {
|
if activeSessions == nil {
|
||||||
@@ -283,8 +285,8 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取窗口费用(并行查询)
|
// 仅非 lite 模式获取窗口费用(PostgreSQL 聚合查询,高开销)
|
||||||
if len(windowCostAccountIDs) > 0 {
|
if !lite && len(windowCostAccountIDs) > 0 {
|
||||||
windowCosts = make(map[int64]float64)
|
windowCosts = make(map[int64]float64)
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||||
@@ -344,7 +346,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
result[i] = item
|
result[i] = item
|
||||||
}
|
}
|
||||||
|
|
||||||
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
|
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite)
|
||||||
if etag != "" {
|
if etag != "" {
|
||||||
c.Header("ETag", etag)
|
c.Header("ETag", etag)
|
||||||
c.Header("Vary", "If-None-Match")
|
c.Header("Vary", "If-None-Match")
|
||||||
@@ -362,6 +364,7 @@ func buildAccountsListETag(
|
|||||||
total int64,
|
total int64,
|
||||||
page, pageSize int,
|
page, pageSize int,
|
||||||
platform, accountType, status, search string,
|
platform, accountType, status, search string,
|
||||||
|
lite bool,
|
||||||
) string {
|
) string {
|
||||||
payload := struct {
|
payload := struct {
|
||||||
Total int64 `json:"total"`
|
Total int64 `json:"total"`
|
||||||
@@ -371,6 +374,7 @@ func buildAccountsListETag(
|
|||||||
AccountType string `json:"type"`
|
AccountType string `json:"type"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Search string `json:"search"`
|
Search string `json:"search"`
|
||||||
|
Lite bool `json:"lite"`
|
||||||
Items []AccountWithConcurrency `json:"items"`
|
Items []AccountWithConcurrency `json:"items"`
|
||||||
}{
|
}{
|
||||||
Total: total,
|
Total: total,
|
||||||
@@ -380,6 +384,7 @@ func buildAccountsListETag(
|
|||||||
AccountType: accountType,
|
AccountType: accountType,
|
||||||
Status: status,
|
Status: status,
|
||||||
Search: search,
|
Search: search,
|
||||||
|
Lite: lite,
|
||||||
Items: items,
|
Items: items,
|
||||||
}
|
}
|
||||||
raw, err := json.Marshal(payload)
|
raw, err := json.Marshal(payload)
|
||||||
@@ -1398,18 +1403,41 @@ func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(req.AccountIDs) == 0 {
|
accountIDs := normalizeInt64IDList(req.AccountIDs)
|
||||||
|
if len(accountIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
|
cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs)
|
||||||
|
if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok {
|
||||||
|
if cached.ETag != "" {
|
||||||
|
c.Header("ETag", cached.ETag)
|
||||||
|
c.Header("Vary", "If-None-Match")
|
||||||
|
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||||
|
c.Status(http.StatusNotModified)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, gin.H{"stats": stats})
|
payload := gin.H{"stats": stats}
|
||||||
|
cached := accountTodayStatsBatchCache.Set(cacheKey, payload)
|
||||||
|
if cached.ETag != "" {
|
||||||
|
c.Header("ETag", cached.ETag)
|
||||||
|
c.Header("Vary", "If-None-Match")
|
||||||
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSchedulableRequest represents the request body for setting schedulable status
|
// SetSchedulableRequest represents the request body for setting schedulable status
|
||||||
|
|||||||
25
backend/internal/handler/admin/account_today_stats_cache.go
Normal file
25
backend/internal/handler/admin/account_today_stats_cache.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second)
|
||||||
|
|
||||||
|
func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string {
|
||||||
|
if len(accountIDs) == 0 {
|
||||||
|
return "accounts_today_stats_empty"
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(accountIDs) * 6)
|
||||||
|
_, _ = b.WriteString("accounts_today_stats:")
|
||||||
|
for i, id := range accountIDs {
|
||||||
|
if i > 0 {
|
||||||
|
_ = b.WriteByte(',')
|
||||||
|
}
|
||||||
|
_, _ = b.WriteString(strconv.FormatInt(id, 10))
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -460,6 +461,9 @@ type BatchUsersUsageRequest struct {
|
|||||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
|
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
|
|
||||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||||
// POST /api/v1/admin/dashboard/users-usage
|
// POST /api/v1/admin/dashboard/users-usage
|
||||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||||
@@ -469,18 +473,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(req.UserIDs) == 0 {
|
userIDs := normalizeInt64IDList(req.UserIDs)
|
||||||
|
if len(userIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
|
keyRaw, _ := json.Marshal(struct {
|
||||||
|
UserIDs []int64 `json:"user_ids"`
|
||||||
|
}{
|
||||||
|
UserIDs: userIDs,
|
||||||
|
})
|
||||||
|
cacheKey := string(keyRaw)
|
||||||
|
if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok {
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage stats")
|
response.Error(c, 500, "Failed to get user usage stats")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, gin.H{"stats": stats})
|
payload := gin.H{"stats": stats}
|
||||||
|
dashboardBatchUsersUsageCache.Set(cacheKey, payload)
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
|
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
|
||||||
@@ -497,16 +517,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(req.APIKeyIDs) == 0 {
|
apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs)
|
||||||
|
if len(apiKeyIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
|
keyRaw, _ := json.Marshal(struct {
|
||||||
|
APIKeyIDs []int64 `json:"api_key_ids"`
|
||||||
|
}{
|
||||||
|
APIKeyIDs: apiKeyIDs,
|
||||||
|
})
|
||||||
|
cacheKey := string(keyRaw)
|
||||||
|
if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok {
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get API key usage stats")
|
response.Error(c, 500, "Failed to get API key usage stats")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, gin.H{"stats": stats})
|
payload := gin.H{"stats": stats}
|
||||||
|
dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload)
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, payload)
|
||||||
}
|
}
|
||||||
|
|||||||
292
backend/internal/handler/admin/dashboard_snapshot_v2_handler.go
Normal file
292
backend/internal/handler/admin/dashboard_snapshot_v2_handler.go
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||||
|
|
||||||
|
type dashboardSnapshotV2Stats struct {
|
||||||
|
usagestats.DashboardStats
|
||||||
|
Uptime int64 `json:"uptime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardSnapshotV2Response struct {
|
||||||
|
GeneratedAt string `json:"generated_at"`
|
||||||
|
|
||||||
|
StartDate string `json:"start_date"`
|
||||||
|
EndDate string `json:"end_date"`
|
||||||
|
Granularity string `json:"granularity"`
|
||||||
|
|
||||||
|
Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"`
|
||||||
|
Trend []usagestats.TrendDataPoint `json:"trend,omitempty"`
|
||||||
|
Models []usagestats.ModelStat `json:"models,omitempty"`
|
||||||
|
Groups []usagestats.GroupStat `json:"groups,omitempty"`
|
||||||
|
UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardSnapshotV2Filters struct {
|
||||||
|
UserID int64
|
||||||
|
APIKeyID int64
|
||||||
|
AccountID int64
|
||||||
|
GroupID int64
|
||||||
|
Model string
|
||||||
|
RequestType *int16
|
||||||
|
Stream *bool
|
||||||
|
BillingType *int8
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardSnapshotV2CacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
Granularity string `json:"granularity"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
AccountID int64 `json:"account_id"`
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
RequestType *int16 `json:"request_type"`
|
||||||
|
Stream *bool `json:"stream"`
|
||||||
|
BillingType *int8 `json:"billing_type"`
|
||||||
|
IncludeStats bool `json:"include_stats"`
|
||||||
|
IncludeTrend bool `json:"include_trend"`
|
||||||
|
IncludeModels bool `json:"include_models"`
|
||||||
|
IncludeGroups bool `json:"include_groups"`
|
||||||
|
IncludeUsersTrend bool `json:"include_users_trend"`
|
||||||
|
UsersTrendLimit int `json:"users_trend_limit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||||
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day"))
|
||||||
|
if granularity != "hour" {
|
||||||
|
granularity = "day"
|
||||||
|
}
|
||||||
|
|
||||||
|
includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true)
|
||||||
|
includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true)
|
||||||
|
includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true)
|
||||||
|
includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false)
|
||||||
|
includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false)
|
||||||
|
usersTrendLimit := 12
|
||||||
|
if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" {
|
||||||
|
if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 {
|
||||||
|
usersTrendLimit = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filters, err := parseDashboardSnapshotV2Filters(c)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
UserID: filters.UserID,
|
||||||
|
APIKeyID: filters.APIKeyID,
|
||||||
|
AccountID: filters.AccountID,
|
||||||
|
GroupID: filters.GroupID,
|
||||||
|
Model: filters.Model,
|
||||||
|
RequestType: filters.RequestType,
|
||||||
|
Stream: filters.Stream,
|
||||||
|
BillingType: filters.BillingType,
|
||||||
|
IncludeStats: includeStats,
|
||||||
|
IncludeTrend: includeTrend,
|
||||||
|
IncludeModels: includeModels,
|
||||||
|
IncludeGroups: includeGroups,
|
||||||
|
IncludeUsersTrend: includeUsersTrend,
|
||||||
|
UsersTrendLimit: usersTrendLimit,
|
||||||
|
})
|
||||||
|
cacheKey := string(keyRaw)
|
||||||
|
|
||||||
|
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||||
|
if cached.ETag != "" {
|
||||||
|
c.Header("ETag", cached.ETag)
|
||||||
|
c.Header("Vary", "If-None-Match")
|
||||||
|
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||||
|
c.Status(http.StatusNotModified)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &dashboardSnapshotV2Response{
|
||||||
|
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||||
|
StartDate: startTime.Format("2006-01-02"),
|
||||||
|
EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
|
Granularity: granularity,
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeStats {
|
||||||
|
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Stats = &dashboardSnapshotV2Stats{
|
||||||
|
DashboardStats: *stats,
|
||||||
|
Uptime: int64(time.Since(h.startTime).Seconds()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeTrend {
|
||||||
|
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
||||||
|
c.Request.Context(),
|
||||||
|
startTime,
|
||||||
|
endTime,
|
||||||
|
granularity,
|
||||||
|
filters.UserID,
|
||||||
|
filters.APIKeyID,
|
||||||
|
filters.AccountID,
|
||||||
|
filters.GroupID,
|
||||||
|
filters.Model,
|
||||||
|
filters.RequestType,
|
||||||
|
filters.Stream,
|
||||||
|
filters.BillingType,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get usage trend")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Trend = trend
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeModels {
|
||||||
|
models, err := h.dashboardService.GetModelStatsWithFilters(
|
||||||
|
c.Request.Context(),
|
||||||
|
startTime,
|
||||||
|
endTime,
|
||||||
|
filters.UserID,
|
||||||
|
filters.APIKeyID,
|
||||||
|
filters.AccountID,
|
||||||
|
filters.GroupID,
|
||||||
|
filters.RequestType,
|
||||||
|
filters.Stream,
|
||||||
|
filters.BillingType,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get model statistics")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Models = models
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeGroups {
|
||||||
|
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
||||||
|
c.Request.Context(),
|
||||||
|
startTime,
|
||||||
|
endTime,
|
||||||
|
filters.UserID,
|
||||||
|
filters.APIKeyID,
|
||||||
|
filters.AccountID,
|
||||||
|
filters.GroupID,
|
||||||
|
filters.RequestType,
|
||||||
|
filters.Stream,
|
||||||
|
filters.BillingType,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get group statistics")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Groups = groups
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeUsersTrend {
|
||||||
|
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
||||||
|
c.Request.Context(),
|
||||||
|
startTime,
|
||||||
|
endTime,
|
||||||
|
granularity,
|
||||||
|
usersTrendLimit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get user usage trend")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.UsersTrend = usersTrend
|
||||||
|
}
|
||||||
|
|
||||||
|
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||||
|
if cached.ETag != "" {
|
||||||
|
c.Header("ETag", cached.ETag)
|
||||||
|
c.Header("Vary", "If-None-Match")
|
||||||
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||||
|
filters := &dashboardSnapshotV2Filters{
|
||||||
|
Model: strings.TrimSpace(c.Query("model")),
|
||||||
|
}
|
||||||
|
|
||||||
|
if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" {
|
||||||
|
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
filters.UserID = id
|
||||||
|
}
|
||||||
|
if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" {
|
||||||
|
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
filters.APIKeyID = id
|
||||||
|
}
|
||||||
|
if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" {
|
||||||
|
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
filters.AccountID = id
|
||||||
|
}
|
||||||
|
if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" {
|
||||||
|
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
filters.GroupID = id
|
||||||
|
}
|
||||||
|
|
||||||
|
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||||
|
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
value := int16(parsed)
|
||||||
|
filters.RequestType = &value
|
||||||
|
} else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" {
|
||||||
|
streamVal, err := strconv.ParseBool(streamStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
filters.Stream = &streamVal
|
||||||
|
}
|
||||||
|
|
||||||
|
if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" {
|
||||||
|
v, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bt := int8(v)
|
||||||
|
filters.BillingType = &bt
|
||||||
|
}
|
||||||
|
|
||||||
|
return filters, nil
|
||||||
|
}
|
||||||
25
backend/internal/handler/admin/id_list_utils.go
Normal file
25
backend/internal/handler/admin/id_list_utils.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import "sort"
|
||||||
|
|
||||||
|
func normalizeInt64IDList(ids []int64) []int64 {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]int64, 0, len(ids))
|
||||||
|
seen := make(map[int64]struct{}, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
if id <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
|
||||||
|
return out
|
||||||
|
}
|
||||||
57
backend/internal/handler/admin/id_list_utils_test.go
Normal file
57
backend/internal/handler/admin/id_list_utils_test.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeInt64IDList(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in []int64
|
||||||
|
want []int64
|
||||||
|
}{
|
||||||
|
{"nil input", nil, nil},
|
||||||
|
{"empty input", []int64{}, nil},
|
||||||
|
{"single element", []int64{5}, []int64{5}},
|
||||||
|
{"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}},
|
||||||
|
{"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}},
|
||||||
|
{"zero filtered", []int64{0, 1, 2}, []int64{1, 2}},
|
||||||
|
{"negative filtered", []int64{-5, -1, 3}, []int64{3}},
|
||||||
|
{"all invalid", []int64{0, -1, -2}, []int64{}},
|
||||||
|
{"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := normalizeInt64IDList(tc.in)
|
||||||
|
if tc.want == nil {
|
||||||
|
require.Nil(t, got)
|
||||||
|
} else {
|
||||||
|
require.Equal(t, tc.want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ids []int64
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"empty", nil, "accounts_today_stats_empty"},
|
||||||
|
{"single", []int64{42}, "accounts_today_stats:42"},
|
||||||
|
{"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := buildAccountTodayStatsBatchCacheKey(tc.ids)
|
||||||
|
require.Equal(t, tc.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
145
backend/internal/handler/admin/ops_snapshot_v2_handler.go
Normal file
145
backend/internal/handler/admin/ops_snapshot_v2_handler.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||||
|
|
||||||
|
type opsDashboardSnapshotV2Response struct {
|
||||||
|
GeneratedAt string `json:"generated_at"`
|
||||||
|
|
||||||
|
Overview *service.OpsDashboardOverview `json:"overview"`
|
||||||
|
ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"`
|
||||||
|
ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type opsDashboardSnapshotV2CacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
GroupID *int64 `json:"group_id"`
|
||||||
|
QueryMode service.OpsQueryMode `json:"mode"`
|
||||||
|
BucketSecond int `json:"bucket_second"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request.
|
||||||
|
// GET /api/v1/admin/ops/dashboard/snapshot-v2
|
||||||
|
func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &service.OpsDashboardFilter{
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Platform: strings.TrimSpace(c.Query("platform")),
|
||||||
|
QueryMode: parseOpsQueryMode(c),
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid group_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.GroupID = &id
|
||||||
|
}
|
||||||
|
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
|
||||||
|
|
||||||
|
keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Platform: filter.Platform,
|
||||||
|
GroupID: filter.GroupID,
|
||||||
|
QueryMode: filter.QueryMode,
|
||||||
|
BucketSecond: bucketSeconds,
|
||||||
|
})
|
||||||
|
cacheKey := string(keyRaw)
|
||||||
|
|
||||||
|
if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||||
|
if cached.ETag != "" {
|
||||||
|
c.Header("ETag", cached.ETag)
|
||||||
|
c.Header("Vary", "If-None-Match")
|
||||||
|
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||||
|
c.Status(http.StatusNotModified)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
overview *service.OpsDashboardOverview
|
||||||
|
trend *service.OpsThroughputTrendResponse
|
||||||
|
errTrend *service.OpsErrorTrendResponse
|
||||||
|
)
|
||||||
|
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||||
|
g.Go(func() error {
|
||||||
|
f := *filter
|
||||||
|
result, err := h.opsService.GetDashboardOverview(gctx, &f)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
overview = result
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
g.Go(func() error {
|
||||||
|
f := *filter
|
||||||
|
result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
trend = result
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
g.Go(func() error {
|
||||||
|
f := *filter
|
||||||
|
result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
errTrend = result
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err := g.Wait(); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &opsDashboardSnapshotV2Response{
|
||||||
|
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||||
|
Overview: overview,
|
||||||
|
ThroughputTrend: trend,
|
||||||
|
ErrorTrend: errTrend,
|
||||||
|
}
|
||||||
|
|
||||||
|
cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||||
|
if cached.ETag != "" {
|
||||||
|
c.Header("ETag", cached.ETag)
|
||||||
|
c.Header("Vary", "If-None-Match")
|
||||||
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, resp)
|
||||||
|
}
|
||||||
155
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
155
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ScheduledTestHandler handles admin scheduled-test-plan management.
|
||||||
|
type ScheduledTestHandler struct {
|
||||||
|
scheduledTestSvc *service.ScheduledTestService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScheduledTestHandler creates a new ScheduledTestHandler.
|
||||||
|
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
|
||||||
|
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
|
||||||
|
}
|
||||||
|
|
||||||
|
type createScheduledTestPlanRequest struct {
|
||||||
|
AccountID int64 `json:"account_id" binding:"required"`
|
||||||
|
ModelID string `json:"model_id"`
|
||||||
|
CronExpression string `json:"cron_expression" binding:"required"`
|
||||||
|
Enabled *bool `json:"enabled"`
|
||||||
|
MaxResults int `json:"max_results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type updateScheduledTestPlanRequest struct {
|
||||||
|
ModelID string `json:"model_id"`
|
||||||
|
CronExpression string `json:"cron_expression"`
|
||||||
|
Enabled *bool `json:"enabled"`
|
||||||
|
MaxResults int `json:"max_results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||||
|
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
|
||||||
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid account id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, plans)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create POST /admin/scheduled-test-plans
|
||||||
|
func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||||
|
var req createScheduledTestPlanRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
plan := &service.ScheduledTestPlan{
|
||||||
|
AccountID: req.AccountID,
|
||||||
|
ModelID: req.ModelID,
|
||||||
|
CronExpression: req.CronExpression,
|
||||||
|
Enabled: true,
|
||||||
|
MaxResults: req.MaxResults,
|
||||||
|
}
|
||||||
|
if req.Enabled != nil {
|
||||||
|
plan.Enabled = *req.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, created)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update PUT /admin/scheduled-test-plans/:id
|
||||||
|
func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||||
|
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid plan id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "plan not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req updateScheduledTestPlanRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ModelID != "" {
|
||||||
|
existing.ModelID = req.ModelID
|
||||||
|
}
|
||||||
|
if req.CronExpression != "" {
|
||||||
|
existing.CronExpression = req.CronExpression
|
||||||
|
}
|
||||||
|
if req.Enabled != nil {
|
||||||
|
existing.Enabled = *req.Enabled
|
||||||
|
}
|
||||||
|
if req.MaxResults > 0 {
|
||||||
|
existing.MaxResults = req.MaxResults
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete DELETE /admin/scheduled-test-plans/:id
|
||||||
|
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
|
||||||
|
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid plan id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
|
||||||
|
response.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListResults GET /admin/scheduled-test-plans/:id/results
|
||||||
|
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
|
||||||
|
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid plan id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := 50
|
||||||
|
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
|
||||||
|
limit = l
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, results)
|
||||||
|
}
|
||||||
@@ -77,6 +77,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
response.Success(c, dto.SystemSettings{
|
response.Success(c, dto.SystemSettings{
|
||||||
RegistrationEnabled: settings.RegistrationEnabled,
|
RegistrationEnabled: settings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||||
|
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
@@ -130,12 +131,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
// UpdateSettingsRequest 更新设置请求
|
// UpdateSettingsRequest 更新设置请求
|
||||||
type UpdateSettingsRequest struct {
|
type UpdateSettingsRequest struct {
|
||||||
// 注册设置
|
// 注册设置
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SMTPHost string `json:"smtp_host"`
|
SMTPHost string `json:"smtp_host"`
|
||||||
@@ -426,50 +428,51 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
settings := &service.SystemSettings{
|
settings := &service.SystemSettings{
|
||||||
RegistrationEnabled: req.RegistrationEnabled,
|
RegistrationEnabled: req.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||||
TotpEnabled: req.TotpEnabled,
|
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||||
SMTPHost: req.SMTPHost,
|
TotpEnabled: req.TotpEnabled,
|
||||||
SMTPPort: req.SMTPPort,
|
SMTPHost: req.SMTPHost,
|
||||||
SMTPUsername: req.SMTPUsername,
|
SMTPPort: req.SMTPPort,
|
||||||
SMTPPassword: req.SMTPPassword,
|
SMTPUsername: req.SMTPUsername,
|
||||||
SMTPFrom: req.SMTPFrom,
|
SMTPPassword: req.SMTPPassword,
|
||||||
SMTPFromName: req.SMTPFromName,
|
SMTPFrom: req.SMTPFrom,
|
||||||
SMTPUseTLS: req.SMTPUseTLS,
|
SMTPFromName: req.SMTPFromName,
|
||||||
TurnstileEnabled: req.TurnstileEnabled,
|
SMTPUseTLS: req.SMTPUseTLS,
|
||||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
TurnstileEnabled: req.TurnstileEnabled,
|
||||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||||
SiteName: req.SiteName,
|
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||||
SiteLogo: req.SiteLogo,
|
SiteName: req.SiteName,
|
||||||
SiteSubtitle: req.SiteSubtitle,
|
SiteLogo: req.SiteLogo,
|
||||||
APIBaseURL: req.APIBaseURL,
|
SiteSubtitle: req.SiteSubtitle,
|
||||||
ContactInfo: req.ContactInfo,
|
APIBaseURL: req.APIBaseURL,
|
||||||
DocURL: req.DocURL,
|
ContactInfo: req.ContactInfo,
|
||||||
HomeContent: req.HomeContent,
|
DocURL: req.DocURL,
|
||||||
HideCcsImportButton: req.HideCcsImportButton,
|
HomeContent: req.HomeContent,
|
||||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
HideCcsImportButton: req.HideCcsImportButton,
|
||||||
PurchaseSubscriptionURL: purchaseURL,
|
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||||
SoraClientEnabled: req.SoraClientEnabled,
|
PurchaseSubscriptionURL: purchaseURL,
|
||||||
CustomMenuItems: customMenuJSON,
|
SoraClientEnabled: req.SoraClientEnabled,
|
||||||
DefaultConcurrency: req.DefaultConcurrency,
|
CustomMenuItems: customMenuJSON,
|
||||||
DefaultBalance: req.DefaultBalance,
|
DefaultConcurrency: req.DefaultConcurrency,
|
||||||
DefaultSubscriptions: defaultSubscriptions,
|
DefaultBalance: req.DefaultBalance,
|
||||||
EnableModelFallback: req.EnableModelFallback,
|
DefaultSubscriptions: defaultSubscriptions,
|
||||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
EnableModelFallback: req.EnableModelFallback,
|
||||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||||
FallbackModelGemini: req.FallbackModelGemini,
|
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
FallbackModelGemini: req.FallbackModelGemini,
|
||||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||||
|
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||||
OpsMonitoringEnabled: func() bool {
|
OpsMonitoringEnabled: func() bool {
|
||||||
if req.OpsMonitoringEnabled != nil {
|
if req.OpsMonitoringEnabled != nil {
|
||||||
return *req.OpsMonitoringEnabled
|
return *req.OpsMonitoringEnabled
|
||||||
@@ -520,6 +523,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
response.Success(c, dto.SystemSettings{
|
response.Success(c, dto.SystemSettings{
|
||||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||||
|
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||||
@@ -598,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||||
changed = append(changed, "email_verify_enabled")
|
changed = append(changed, "email_verify_enabled")
|
||||||
}
|
}
|
||||||
|
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
|
||||||
|
changed = append(changed, "registration_email_suffix_whitelist")
|
||||||
|
}
|
||||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||||
changed = append(changed, "password_reset_enabled")
|
changed = append(changed, "password_reset_enabled")
|
||||||
}
|
}
|
||||||
@@ -747,6 +754,18 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
|
|||||||
return normalized
|
return normalized
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func equalStringSlice(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return false
|
return false
|
||||||
@@ -800,7 +819,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
|||||||
|
|
||||||
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -886,7 +905,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
|||||||
`
|
`
|
||||||
|
|
||||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.BadRequest(c, "Failed to send test email: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
95
backend/internal/handler/admin/snapshot_cache.go
Normal file
95
backend/internal/handler/admin/snapshot_cache.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type snapshotCacheEntry struct {
|
||||||
|
ETag string
|
||||||
|
Payload any
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type snapshotCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
ttl time.Duration
|
||||||
|
items map[string]snapshotCacheEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = 30 * time.Second
|
||||||
|
}
|
||||||
|
return &snapshotCache{
|
||||||
|
ttl: ttl,
|
||||||
|
items: make(map[string]snapshotCacheEntry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) {
|
||||||
|
if c == nil || key == "" {
|
||||||
|
return snapshotCacheEntry{}, false
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
entry, ok := c.items[key]
|
||||||
|
c.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return snapshotCacheEntry{}, false
|
||||||
|
}
|
||||||
|
if now.After(entry.ExpiresAt) {
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.items, key)
|
||||||
|
c.mu.Unlock()
|
||||||
|
return snapshotCacheEntry{}, false
|
||||||
|
}
|
||||||
|
return entry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
||||||
|
if c == nil {
|
||||||
|
return snapshotCacheEntry{}
|
||||||
|
}
|
||||||
|
entry := snapshotCacheEntry{
|
||||||
|
ETag: buildETagFromAny(payload),
|
||||||
|
Payload: payload,
|
||||||
|
ExpiresAt: time.Now().Add(c.ttl),
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
c.items[key] = entry
|
||||||
|
c.mu.Unlock()
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildETagFromAny(payload any) string {
|
||||||
|
raw, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256(raw)
|
||||||
|
return "\"" + hex.EncodeToString(sum[:]) + "\""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBoolQueryWithDefault(raw string, def bool) bool {
|
||||||
|
value := strings.TrimSpace(strings.ToLower(raw))
|
||||||
|
if value == "" {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
switch value {
|
||||||
|
case "1", "true", "yes", "on":
|
||||||
|
return true
|
||||||
|
case "0", "false", "no", "off":
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
}
|
||||||
128
backend/internal/handler/admin/snapshot_cache_test.go
Normal file
128
backend/internal/handler/admin/snapshot_cache_test.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSnapshotCache_SetAndGet(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
|
||||||
|
entry := c.Set("key1", map[string]string{"hello": "world"})
|
||||||
|
require.NotEmpty(t, entry.ETag)
|
||||||
|
require.NotNil(t, entry.Payload)
|
||||||
|
|
||||||
|
got, ok := c.Get("key1")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, entry.ETag, got.ETag)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_Expiration(t *testing.T) {
|
||||||
|
c := newSnapshotCache(1 * time.Millisecond)
|
||||||
|
|
||||||
|
c.Set("key1", "value")
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
|
||||||
|
_, ok := c.Get("key1")
|
||||||
|
require.False(t, ok, "expired entry should not be returned")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_GetEmptyKey(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
_, ok := c.Get("")
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_GetMiss(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
_, ok := c.Get("nonexistent")
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_NilReceiver(t *testing.T) {
|
||||||
|
var c *snapshotCache
|
||||||
|
_, ok := c.Get("key")
|
||||||
|
require.False(t, ok)
|
||||||
|
|
||||||
|
entry := c.Set("key", "value")
|
||||||
|
require.Empty(t, entry.ETag)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_SetEmptyKey(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
|
||||||
|
// Set with empty key should return entry but not store it
|
||||||
|
entry := c.Set("", "value")
|
||||||
|
require.NotEmpty(t, entry.ETag)
|
||||||
|
|
||||||
|
_, ok := c.Get("")
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_DefaultTTL(t *testing.T) {
|
||||||
|
c := newSnapshotCache(0)
|
||||||
|
require.Equal(t, 30*time.Second, c.ttl)
|
||||||
|
|
||||||
|
c2 := newSnapshotCache(-1 * time.Second)
|
||||||
|
require.Equal(t, 30*time.Second, c2.ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_ETagDeterministic(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
payload := map[string]int{"a": 1, "b": 2}
|
||||||
|
|
||||||
|
entry1 := c.Set("k1", payload)
|
||||||
|
entry2 := c.Set("k2", payload)
|
||||||
|
require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_ETagFormat(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
entry := c.Set("k", "test")
|
||||||
|
// ETag should be quoted hex string: "abcdef..."
|
||||||
|
require.True(t, len(entry.ETag) > 2)
|
||||||
|
require.Equal(t, byte('"'), entry.ETag[0])
|
||||||
|
require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
||||||
|
// channels are not JSON-serializable
|
||||||
|
etag := buildETagFromAny(make(chan int))
|
||||||
|
require.Empty(t, etag)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
raw string
|
||||||
|
def bool
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"empty returns default true", "", true, true},
|
||||||
|
{"empty returns default false", "", false, false},
|
||||||
|
{"1", "1", false, true},
|
||||||
|
{"true", "true", false, true},
|
||||||
|
{"TRUE", "TRUE", false, true},
|
||||||
|
{"yes", "yes", false, true},
|
||||||
|
{"on", "on", false, true},
|
||||||
|
{"0", "0", true, false},
|
||||||
|
{"false", "false", true, false},
|
||||||
|
{"FALSE", "FALSE", true, false},
|
||||||
|
{"no", "no", true, false},
|
||||||
|
{"off", "off", true, false},
|
||||||
|
{"whitespace trimmed", " true ", false, true},
|
||||||
|
{"unknown returns default true", "maybe", true, true},
|
||||||
|
{"unknown returns default false", "maybe", false, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := parseBoolQueryWithDefault(tc.raw, tc.def)
|
||||||
|
require.Equal(t, tc.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -61,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct {
|
|||||||
// GET /api/v1/admin/usage
|
// GET /api/v1/admin/usage
|
||||||
func (h *UsageHandler) List(c *gin.Context) {
|
func (h *UsageHandler) List(c *gin.Context) {
|
||||||
page, pageSize := response.ParsePagination(c)
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
exactTotal := false
|
||||||
|
if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" {
|
||||||
|
parsed, err := strconv.ParseBool(exactTotalRaw)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid exact_total value, use true or false")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
exactTotal = parsed
|
||||||
|
}
|
||||||
|
|
||||||
// Parse filters
|
// Parse filters
|
||||||
var userID, apiKeyID, accountID, groupID int64
|
var userID, apiKeyID, accountID, groupID int64
|
||||||
@@ -167,6 +176,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
StartTime: startTime,
|
StartTime: startTime,
|
||||||
EndTime: endTime,
|
EndTime: endTime,
|
||||||
|
ExactTotal: exactTotal,
|
||||||
}
|
}
|
||||||
|
|
||||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||||
|
|||||||
@@ -80,6 +80,29 @@ func TestAdminUsageListInvalidStream(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminUsageListExactTotalTrue(t *testing.T) {
|
||||||
|
repo := &adminUsageRepoCapture{}
|
||||||
|
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.True(t, repo.listFilters.ExactTotal)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUsageListInvalidExactTotal(t *testing.T) {
|
||||||
|
repo := &adminUsageRepoCapture{}
|
||||||
|
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
|
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
|
||||||
repo := &adminUsageRepoCapture{}
|
repo := &adminUsageRepoCapture{}
|
||||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct {
|
|||||||
Attributes map[int64]map[int64]string `json:"attributes"`
|
Attributes map[int64]map[int64]string `json:"attributes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var userAttributesBatchCache = newSnapshotCache(30 * time.Second)
|
||||||
|
|
||||||
// AttributeDefinitionResponse represents attribute definition response
|
// AttributeDefinitionResponse represents attribute definition response
|
||||||
type AttributeDefinitionResponse struct {
|
type AttributeDefinitionResponse struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
@@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(req.UserIDs) == 0 {
|
userIDs := normalizeInt64IDList(req.UserIDs)
|
||||||
|
if len(userIDs) == 0 {
|
||||||
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
|
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
|
keyRaw, _ := json.Marshal(struct {
|
||||||
|
UserIDs []int64 `json:"user_ids"`
|
||||||
|
}{
|
||||||
|
UserIDs: userIDs,
|
||||||
|
})
|
||||||
|
cacheKey := string(keyRaw)
|
||||||
|
if cached, ok := userAttributesBatchCache.Get(cacheKey); ok {
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
|
payload := BatchUserAttributesResponse{Attributes: attrs}
|
||||||
|
userAttributesBatchCache.Set(cacheKey, payload)
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, payload)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -91,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) {
|
|||||||
Search: search,
|
Search: search,
|
||||||
Attributes: parseAttributeFilters(c),
|
Attributes: parseAttributeFilters(c),
|
||||||
}
|
}
|
||||||
|
if raw, ok := c.GetQuery("include_subscriptions"); ok {
|
||||||
|
includeSubscriptions := parseBoolQueryWithDefault(raw, true)
|
||||||
|
filters.IncludeSubscriptions = &includeSubscriptions
|
||||||
|
}
|
||||||
|
|
||||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
|
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
@@ -73,7 +74,23 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
|||||||
page, pageSize := response.ParsePagination(c)
|
page, pageSize := response.ParsePagination(c)
|
||||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
|
|
||||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
|
// Parse filter parameters
|
||||||
|
var filters service.APIKeyListFilters
|
||||||
|
if search := strings.TrimSpace(c.Query("search")); search != "" {
|
||||||
|
if len(search) > 100 {
|
||||||
|
search = search[:100]
|
||||||
|
}
|
||||||
|
filters.Search = search
|
||||||
|
}
|
||||||
|
filters.Status = c.Query("status")
|
||||||
|
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||||
|
gid, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
filters.GroupID = &gid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params, filters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -17,13 +17,14 @@ type CustomMenuItem struct {
|
|||||||
|
|
||||||
// SystemSettings represents the admin settings API response payload.
|
// SystemSettings represents the admin settings API response payload.
|
||||||
type SystemSettings struct {
|
type SystemSettings struct {
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||||
|
|
||||||
SMTPHost string `json:"smtp_host"`
|
SMTPHost string `json:"smtp_host"`
|
||||||
SMTPPort int `json:"smtp_port"`
|
SMTPPort int `json:"smtp_port"`
|
||||||
@@ -88,28 +89,29 @@ type DefaultSubscriptionSetting struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PublicSettings struct {
|
type PublicSettings struct {
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||||
SiteName string `json:"site_name"`
|
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||||
SiteLogo string `json:"site_logo"`
|
SiteName string `json:"site_name"`
|
||||||
SiteSubtitle string `json:"site_subtitle"`
|
SiteLogo string `json:"site_logo"`
|
||||||
APIBaseURL string `json:"api_base_url"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
ContactInfo string `json:"contact_info"`
|
APIBaseURL string `json:"api_base_url"`
|
||||||
DocURL string `json:"doc_url"`
|
ContactInfo string `json:"contact_info"`
|
||||||
HomeContent string `json:"home_content"`
|
DocURL string `json:"doc_url"`
|
||||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
HomeContent string `json:"home_content"`
|
||||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
Version string `json:"version"`
|
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||||
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ type AdminHandlers struct {
|
|||||||
UserAttribute *admin.UserAttributeHandler
|
UserAttribute *admin.UserAttributeHandler
|
||||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||||
APIKey *admin.AdminAPIKeyHandler
|
APIKey *admin.AdminAPIKeyHandler
|
||||||
|
ScheduledTest *admin.ScheduledTestHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handlers contains all HTTP handlers
|
// Handlers contains all HTTP handlers
|
||||||
|
|||||||
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var handlerStructuredLogCaptureMu sync.Mutex
|
||||||
|
|
||||||
|
type handlerInMemoryLogSink struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
events []*logger.LogEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||||
|
if event == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cloned := *event
|
||||||
|
if event.Fields != nil {
|
||||||
|
cloned.Fields = make(map[string]any, len(event.Fields))
|
||||||
|
for k, v := range event.Fields {
|
||||||
|
cloned.Fields[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.events = append(s.events, &cloned)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
wantLevel := strings.ToLower(strings.TrimSpace(level))
|
||||||
|
for _, ev := range s.events {
|
||||||
|
if ev == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for _, ev := range s.events {
|
||||||
|
if ev == nil || ev.Fields == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) {
|
||||||
|
t.Helper()
|
||||||
|
handlerStructuredLogCaptureMu.Lock()
|
||||||
|
|
||||||
|
err := logger.Init(logger.InitOptions{
|
||||||
|
Level: "debug",
|
||||||
|
Format: "json",
|
||||||
|
ServiceName: "sub2api",
|
||||||
|
Environment: "test",
|
||||||
|
Output: logger.OutputOptions{
|
||||||
|
ToStdout: true,
|
||||||
|
ToFile: false,
|
||||||
|
},
|
||||||
|
Sampling: logger.SamplingOptions{Enabled: false},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sink := &handlerInMemoryLogSink{}
|
||||||
|
logger.SetSink(sink)
|
||||||
|
return sink, func() {
|
||||||
|
logger.SetSink(nil)
|
||||||
|
handlerStructuredLogCaptureMu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOpenAIRemoteCompactPath(t *testing.T) {
|
||||||
|
require.False(t, isOpenAIRemoteCompactPath(nil))
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||||
|
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||||
|
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil)
|
||||||
|
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||||
|
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
require.False(t, isOpenAIRemoteCompactPath(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
logSink, restore := captureHandlerStructuredLog(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||||
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||||
|
c.Set(opsModelKey, "gpt-5.3-codex")
|
||||||
|
c.Set(opsAccountIDKey, int64(123))
|
||||||
|
c.Header("x-request-id", "rid-compact-ok")
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond))
|
||||||
|
|
||||||
|
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("status_code", "200"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("account_id", "123"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
logSink, restore := captureHandlerStructuredLog(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
|
||||||
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||||
|
c.Status(http.StatusBadGateway)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||||
|
|
||||||
|
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("status_code", "502"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("path", "/responses/compact"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
logSink, restore := captureHandlerStructuredLog(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||||
|
|
||||||
|
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||||
|
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
logSink, restore := captureHandlerStructuredLog(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.Responses(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||||
|
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("status_code", "401"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||||
|
}
|
||||||
@@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
|
|||||||
errorPassthroughService *service.ErrorPassthroughService
|
errorPassthroughService *service.ErrorPassthroughService
|
||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
maxAccountSwitches int
|
maxAccountSwitches int
|
||||||
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||||
@@ -61,6 +62,7 @@ func NewOpenAIGatewayHandler(
|
|||||||
errorPassthroughService: errorPassthroughService,
|
errorPassthroughService: errorPassthroughService,
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,6 +72,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||||
streamStarted := false
|
streamStarted := false
|
||||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||||
|
compactStartedAt := time.Now()
|
||||||
|
defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt)
|
||||||
setOpenAIClientTransportHTTP(c)
|
setOpenAIClientTransportHTTP(c)
|
||||||
|
|
||||||
requestStart := time.Now()
|
requestStart := time.Now()
|
||||||
@@ -340,6 +344,86 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOpenAIRemoteCompactPath(c *gin.Context) bool {
|
||||||
|
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
||||||
|
return strings.HasSuffix(normalizedPath, "/responses/compact")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) {
|
||||||
|
if !isOpenAIRemoteCompactPath(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
path string
|
||||||
|
status int
|
||||||
|
)
|
||||||
|
if c != nil {
|
||||||
|
if c.Request != nil {
|
||||||
|
ctx = c.Request.Context()
|
||||||
|
if c.Request.URL != nil {
|
||||||
|
path = strings.TrimSpace(c.Request.URL.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.Writer != nil {
|
||||||
|
status = c.Writer.Status()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
outcome := "failed"
|
||||||
|
if status >= 200 && status < 300 {
|
||||||
|
outcome = "succeeded"
|
||||||
|
}
|
||||||
|
latencyMs := time.Since(startedAt).Milliseconds()
|
||||||
|
if latencyMs < 0 {
|
||||||
|
latencyMs = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := []zap.Field{
|
||||||
|
zap.String("component", "handler.openai_gateway.responses"),
|
||||||
|
zap.Bool("remote_compact", true),
|
||||||
|
zap.String("compact_outcome", outcome),
|
||||||
|
zap.Int("status_code", status),
|
||||||
|
zap.Int64("latency_ms", latencyMs),
|
||||||
|
zap.String("path", path),
|
||||||
|
zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c != nil {
|
||||||
|
if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" {
|
||||||
|
fields = append(fields, zap.String("request_user_agent", userAgent))
|
||||||
|
}
|
||||||
|
if v, ok := c.Get(opsModelKey); ok {
|
||||||
|
if model, ok := v.(string); ok && strings.TrimSpace(model) != "" {
|
||||||
|
fields = append(fields, zap.String("request_model", strings.TrimSpace(model)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := c.Get(opsAccountIDKey); ok {
|
||||||
|
if accountID, ok := v.(int64); ok && accountID > 0 {
|
||||||
|
fields = append(fields, zap.Int64("account_id", accountID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.Writer != nil {
|
||||||
|
if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" {
|
||||||
|
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||||
|
} else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" {
|
||||||
|
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log := logger.FromContext(ctx).With(fields...)
|
||||||
|
if outcome == "succeeded" {
|
||||||
|
log.Info("codex.remote_compact.succeeded")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Warn("codex.remote_compact.failed")
|
||||||
|
}
|
||||||
|
|
||||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -32,27 +32,28 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.PublicSettings{
|
response.Success(c, dto.PublicSettings{
|
||||||
RegistrationEnabled: settings.RegistrationEnabled,
|
RegistrationEnabled: settings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TurnstileEnabled: settings.TurnstileEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
TurnstileEnabled: settings.TurnstileEnabled,
|
||||||
SiteName: settings.SiteName,
|
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||||
SiteLogo: settings.SiteLogo,
|
SiteName: settings.SiteName,
|
||||||
SiteSubtitle: settings.SiteSubtitle,
|
SiteLogo: settings.SiteLogo,
|
||||||
APIBaseURL: settings.APIBaseURL,
|
SiteSubtitle: settings.SiteSubtitle,
|
||||||
ContactInfo: settings.ContactInfo,
|
APIBaseURL: settings.APIBaseURL,
|
||||||
DocURL: settings.DocURL,
|
ContactInfo: settings.ContactInfo,
|
||||||
HomeContent: settings.HomeContent,
|
DocURL: settings.DocURL,
|
||||||
HideCcsImportButton: settings.HideCcsImportButton,
|
HomeContent: settings.HomeContent,
|
||||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
HideCcsImportButton: settings.HideCcsImportButton,
|
||||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||||
SoraClientEnabled: settings.SoraClientEnabled,
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
Version: h.version,
|
SoraClientEnabled: settings.SoraClientEnabled,
|
||||||
|
Version: h.version,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -996,7 +996,7 @@ func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*se
|
|||||||
}
|
}
|
||||||
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
|
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
|
||||||
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
|
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
|
||||||
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ func ProvideAdminHandlers(
|
|||||||
userAttributeHandler *admin.UserAttributeHandler,
|
userAttributeHandler *admin.UserAttributeHandler,
|
||||||
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
||||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||||
|
scheduledTestHandler *admin.ScheduledTestHandler,
|
||||||
) *AdminHandlers {
|
) *AdminHandlers {
|
||||||
return &AdminHandlers{
|
return &AdminHandlers{
|
||||||
Dashboard: dashboardHandler,
|
Dashboard: dashboardHandler,
|
||||||
@@ -53,6 +54,7 @@ func ProvideAdminHandlers(
|
|||||||
UserAttribute: userAttributeHandler,
|
UserAttribute: userAttributeHandler,
|
||||||
ErrorPassthrough: errorPassthroughHandler,
|
ErrorPassthrough: errorPassthroughHandler,
|
||||||
APIKey: apiKeyHandler,
|
APIKey: apiKeyHandler,
|
||||||
|
ScheduledTest: scheduledTestHandler,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,6 +143,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewUserAttributeHandler,
|
admin.NewUserAttributeHandler,
|
||||||
admin.NewErrorPassthroughHandler,
|
admin.NewErrorPassthroughHandler,
|
||||||
admin.NewAdminAPIKeyHandler,
|
admin.NewAdminAPIKeyHandler,
|
||||||
|
admin.NewScheduledTestHandler,
|
||||||
|
|
||||||
// AdminHandlers and Handlers constructors
|
// AdminHandlers and Handlers constructors
|
||||||
ProvideAdminHandlers,
|
ProvideAdminHandlers,
|
||||||
|
|||||||
@@ -57,25 +57,28 @@ type DashboardStats struct {
|
|||||||
|
|
||||||
// TrendDataPoint represents a single point in trend data
|
// TrendDataPoint represents a single point in trend data
|
||||||
type TrendDataPoint struct {
|
type TrendDataPoint struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
InputTokens int64 `json:"input_tokens"`
|
InputTokens int64 `json:"input_tokens"`
|
||||||
OutputTokens int64 `json:"output_tokens"`
|
OutputTokens int64 `json:"output_tokens"`
|
||||||
CacheTokens int64 `json:"cache_tokens"`
|
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelStat represents usage statistics for a single model
|
// ModelStat represents usage statistics for a single model
|
||||||
type ModelStat struct {
|
type ModelStat struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
InputTokens int64 `json:"input_tokens"`
|
InputTokens int64 `json:"input_tokens"`
|
||||||
OutputTokens int64 `json:"output_tokens"`
|
OutputTokens int64 `json:"output_tokens"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
// GroupStat represents usage statistics for a single group
|
// GroupStat represents usage statistics for a single group
|
||||||
@@ -154,6 +157,8 @@ type UsageLogFilters struct {
|
|||||||
BillingType *int8
|
BillingType *int8
|
||||||
StartTime *time.Time
|
StartTime *time.Time
|
||||||
EndTime *time.Time
|
EndTime *time.Time
|
||||||
|
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
|
||||||
|
ExactTotal bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageStats represents usage statistics
|
// UsageStats represents usage statistics
|
||||||
|
|||||||
@@ -437,6 +437,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
|||||||
switch status {
|
switch status {
|
||||||
case "rate_limited":
|
case "rate_limited":
|
||||||
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
|
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
|
||||||
|
case "temp_unschedulable":
|
||||||
|
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
|
||||||
|
col := s.C("temp_unschedulable_until")
|
||||||
|
s.Where(entsql.And(
|
||||||
|
entsql.Not(entsql.IsNull(col)),
|
||||||
|
entsql.GT(col, entsql.Expr("NOW()")),
|
||||||
|
))
|
||||||
|
}))
|
||||||
default:
|
default:
|
||||||
q = q.Where(dbaccount.StatusEQ(status))
|
q = q.Where(dbaccount.StatusEQ(status))
|
||||||
}
|
}
|
||||||
@@ -640,7 +648,17 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
|||||||
SetStatus(service.StatusActive).
|
SetStatus(service.StatusActive).
|
||||||
SetErrorMessage("").
|
SetErrorMessage("").
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 清除临时不可调度状态,重置 401 升级链
|
||||||
|
_, _ = r.sql.ExecContext(ctx, `
|
||||||
|
UPDATE accounts
|
||||||
|
SET temp_unschedulable_until = NULL,
|
||||||
|
temp_unschedulable_reason = NULL
|
||||||
|
WHERE id = $1 AND deleted_at IS NULL
|
||||||
|
`, id)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||||
|
|||||||
@@ -281,9 +281,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||||
|
|
||||||
|
// Apply filters
|
||||||
|
if filters.Search != "" {
|
||||||
|
q = q.Where(apikey.Or(
|
||||||
|
apikey.NameContainsFold(filters.Search),
|
||||||
|
apikey.KeyContainsFold(filters.Search),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
if filters.Status != "" {
|
||||||
|
q = q.Where(apikey.StatusEQ(filters.Status))
|
||||||
|
}
|
||||||
|
if filters.GroupID != nil {
|
||||||
|
if *filters.GroupID == 0 {
|
||||||
|
q = q.Where(apikey.GroupIDIsNil())
|
||||||
|
} else {
|
||||||
|
q = q.Where(apikey.GroupIDEQ(*filters.GroupID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
total, err := q.Count(ctx)
|
total, err := q.Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
|
|||||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||||
|
|
||||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
|
||||||
s.Require().NoError(err, "ListByUserID")
|
s.Require().NoError(err, "ListByUserID")
|
||||||
s.Require().Len(keys, 2)
|
s.Require().Len(keys, 2)
|
||||||
s.Require().Equal(int64(2), page.Total)
|
s.Require().Equal(int64(2), page.Total)
|
||||||
@@ -170,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
|
|||||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{})
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(keys, 2)
|
s.Require().Len(keys, 2)
|
||||||
s.Require().Equal(int64(5), page.Total)
|
s.Require().Equal(int64(5), page.Total)
|
||||||
@@ -314,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
|||||||
s.Require().Equal(service.StatusDisabled, got2.Status)
|
s.Require().Equal(service.StatusDisabled, got2.Status)
|
||||||
s.Require().Nil(got2.GroupID)
|
s.Require().Nil(got2.GroupID)
|
||||||
|
|
||||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
|
||||||
s.Require().NoError(err, "ListByUserID")
|
s.Require().NoError(err, "ListByUserID")
|
||||||
s.Require().Equal(int64(1), page.Total)
|
s.Require().Equal(int64(1), page.Total)
|
||||||
s.Require().Len(keys, 1)
|
s.Require().Len(keys, 1)
|
||||||
|
|||||||
183
backend/internal/repository/scheduled_test_repo.go
Normal file
183
backend/internal/repository/scheduled_test_repo.go
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Plan Repository ---
|
||||||
|
|
||||||
|
type scheduledTestPlanRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository {
|
||||||
|
return &scheduledTestPlanRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||||
|
row := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||||
|
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||||
|
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||||
|
return scanPlan(row)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||||
|
row := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||||
|
FROM scheduled_test_plans WHERE id = $1
|
||||||
|
`, id)
|
||||||
|
return scanPlan(row)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||||
|
FROM scheduled_test_plans WHERE account_id = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
return scanPlans(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||||
|
FROM scheduled_test_plans
|
||||||
|
WHERE enabled = true AND next_run_at <= $1
|
||||||
|
ORDER BY next_run_at ASC
|
||||||
|
`, now)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
return scanPlans(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||||
|
row := r.db.QueryRowContext(ctx, `
|
||||||
|
UPDATE scheduled_test_plans
|
||||||
|
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||||
|
WHERE id = $1
|
||||||
|
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||||
|
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||||
|
return scanPlan(row)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error {
|
||||||
|
_, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error {
|
||||||
|
_, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1
|
||||||
|
`, id, lastRunAt, nextRunAt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Result Repository ---
|
||||||
|
|
||||||
|
type scheduledTestResultRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository {
|
||||||
|
return &scheduledTestResultRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) {
|
||||||
|
row := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||||
|
RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||||
|
`, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt)
|
||||||
|
|
||||||
|
out := &service.ScheduledTestResult{}
|
||||||
|
if err := row.Scan(
|
||||||
|
&out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage,
|
||||||
|
&out.LatencyMs, &out.StartedAt, &out.FinishedAt, &out.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||||
|
FROM scheduled_test_results
|
||||||
|
WHERE plan_id = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT $2
|
||||||
|
`, planID, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var results []*service.ScheduledTestResult
|
||||||
|
for rows.Next() {
|
||||||
|
r := &service.ScheduledTestResult{}
|
||||||
|
if err := rows.Scan(
|
||||||
|
&r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage,
|
||||||
|
&r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results = append(results, r)
|
||||||
|
}
|
||||||
|
return results, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error {
|
||||||
|
_, err := r.db.ExecContext(ctx, `
|
||||||
|
DELETE FROM scheduled_test_results
|
||||||
|
WHERE id IN (
|
||||||
|
SELECT id FROM (
|
||||||
|
SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn
|
||||||
|
FROM scheduled_test_results
|
||||||
|
WHERE plan_id = $1
|
||||||
|
) ranked
|
||||||
|
WHERE rn > $2
|
||||||
|
)
|
||||||
|
`, planID, keepCount)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- scan helpers ---
|
||||||
|
|
||||||
|
type scannable interface {
|
||||||
|
Scan(dest ...any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||||
|
p := &service.ScheduledTestPlan{}
|
||||||
|
if err := row.Scan(
|
||||||
|
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||||
|
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanPlans(rows *sql.Rows) ([]*service.ScheduledTestPlan, error) {
|
||||||
|
var plans []*service.ScheduledTestPlan
|
||||||
|
for rows.Next() {
|
||||||
|
p, err := scanPlan(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
plans = append(plans, p)
|
||||||
|
}
|
||||||
|
return plans, rows.Err()
|
||||||
|
}
|
||||||
@@ -122,7 +122,7 @@ func (s *SettingRepoSuite) TestSet_EmptyValue() {
|
|||||||
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
|
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
|
||||||
// 模拟保存站点设置,部分字段有值,部分字段为空
|
// 模拟保存站点设置,部分字段有值,部分字段为空
|
||||||
settings := map[string]string{
|
settings := map[string]string{
|
||||||
"site_name": "AICodex2API",
|
"site_name": "Sub2api",
|
||||||
"site_subtitle": "Subscription to API",
|
"site_subtitle": "Subscription to API",
|
||||||
"site_logo": "", // 用户未上传Logo
|
"site_logo": "", // 用户未上传Logo
|
||||||
"api_base_url": "", // 用户未设置API地址
|
"api_base_url": "", // 用户未设置API地址
|
||||||
@@ -136,7 +136,7 @@ func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
|
|||||||
result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
|
result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
|
||||||
s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")
|
s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")
|
||||||
|
|
||||||
s.Require().Equal("AICodex2API", result["site_name"])
|
s.Require().Equal("Sub2api", result["site_name"])
|
||||||
s.Require().Equal("Subscription to API", result["site_subtitle"])
|
s.Require().Equal("Subscription to API", result["site_subtitle"])
|
||||||
s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
|
s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
|
||||||
s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")
|
s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")
|
||||||
|
|||||||
@@ -1363,7 +1363,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||||
@@ -1401,6 +1402,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||||
|
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||||
@@ -1473,7 +1476,16 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
|||||||
}
|
}
|
||||||
|
|
||||||
whereClause := buildWhere(conditions)
|
whereClause := buildWhere(conditions)
|
||||||
logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params)
|
var (
|
||||||
|
logs []service.UsageLog
|
||||||
|
page *pagination.PaginationResult
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if shouldUseFastUsageLogTotal(filters) {
|
||||||
|
logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params)
|
||||||
|
} else {
|
||||||
|
logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -1484,17 +1496,45 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
|||||||
return logs, page, nil
|
return logs, page, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool {
|
||||||
|
if filters.ExactTotal {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// 强选择过滤下记录集通常较小,保留精确总数。
|
||||||
|
return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0
|
||||||
|
}
|
||||||
|
|
||||||
// UsageStats represents usage statistics
|
// UsageStats represents usage statistics
|
||||||
type UsageStats = usagestats.UsageStats
|
type UsageStats = usagestats.UsageStats
|
||||||
|
|
||||||
// BatchUserUsageStats represents usage stats for a single user
|
// BatchUserUsageStats represents usage stats for a single user
|
||||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||||
|
|
||||||
|
func normalizePositiveInt64IDs(ids []int64) []int64 {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
seen := make(map[int64]struct{}, len(ids))
|
||||||
|
out := make([]int64, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
if id <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
|
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
|
||||||
// If startTime is zero, defaults to 30 days ago.
|
// If startTime is zero, defaults to 30 days ago.
|
||||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
|
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
|
||||||
result := make(map[int64]*BatchUserUsageStats)
|
result := make(map[int64]*BatchUserUsageStats)
|
||||||
if len(userIDs) == 0 {
|
normalizedUserIDs := normalizePositiveInt64IDs(userIDs)
|
||||||
|
if len(normalizedUserIDs) == 0 {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1506,58 +1546,36 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
|||||||
endTime = time.Now()
|
endTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, id := range userIDs {
|
for _, id := range normalizedUserIDs {
|
||||||
result[id] = &BatchUserUsageStats{UserID: id}
|
result[id] = &BatchUserUsageStats{UserID: id}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
SELECT
|
||||||
|
user_id,
|
||||||
|
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
|
||||||
|
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
WHERE user_id = ANY($1)
|
||||||
|
AND created_at >= LEAST($2, $4)
|
||||||
GROUP BY user_id
|
GROUP BY user_id
|
||||||
`
|
`
|
||||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
|
today := timezone.Today()
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var userID int64
|
var userID int64
|
||||||
var total float64
|
var total float64
|
||||||
if err := rows.Scan(&userID, &total); err != nil {
|
var todayTotal float64
|
||||||
|
if err := rows.Scan(&userID, &total, &todayTotal); err != nil {
|
||||||
_ = rows.Close()
|
_ = rows.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if stats, ok := result[userID]; ok {
|
if stats, ok := result[userID]; ok {
|
||||||
stats.TotalActualCost = total
|
stats.TotalActualCost = total
|
||||||
}
|
stats.TodayActualCost = todayTotal
|
||||||
}
|
|
||||||
if err := rows.Close(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
today := timezone.Today()
|
|
||||||
todayQuery := `
|
|
||||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
|
||||||
FROM usage_logs
|
|
||||||
WHERE user_id = ANY($1) AND created_at >= $2
|
|
||||||
GROUP BY user_id
|
|
||||||
`
|
|
||||||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for rows.Next() {
|
|
||||||
var userID int64
|
|
||||||
var total float64
|
|
||||||
if err := rows.Scan(&userID, &total); err != nil {
|
|
||||||
_ = rows.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if stats, ok := result[userID]; ok {
|
|
||||||
stats.TodayActualCost = total
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := rows.Close(); err != nil {
|
if err := rows.Close(); err != nil {
|
||||||
@@ -1577,7 +1595,8 @@ type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
|||||||
// If startTime is zero, defaults to 30 days ago.
|
// If startTime is zero, defaults to 30 days ago.
|
||||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
|
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||||
if len(apiKeyIDs) == 0 {
|
normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs)
|
||||||
|
if len(normalizedAPIKeyIDs) == 0 {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1589,58 +1608,36 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
|||||||
endTime = time.Now()
|
endTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, id := range apiKeyIDs {
|
for _, id := range normalizedAPIKeyIDs {
|
||||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
SELECT
|
||||||
|
api_key_id,
|
||||||
|
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
|
||||||
|
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
WHERE api_key_id = ANY($1)
|
||||||
|
AND created_at >= LEAST($2, $4)
|
||||||
GROUP BY api_key_id
|
GROUP BY api_key_id
|
||||||
`
|
`
|
||||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
|
today := timezone.Today()
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var apiKeyID int64
|
var apiKeyID int64
|
||||||
var total float64
|
var total float64
|
||||||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
var todayTotal float64
|
||||||
|
if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil {
|
||||||
_ = rows.Close()
|
_ = rows.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if stats, ok := result[apiKeyID]; ok {
|
if stats, ok := result[apiKeyID]; ok {
|
||||||
stats.TotalActualCost = total
|
stats.TotalActualCost = total
|
||||||
}
|
stats.TodayActualCost = todayTotal
|
||||||
}
|
|
||||||
if err := rows.Close(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
today := timezone.Today()
|
|
||||||
todayQuery := `
|
|
||||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
|
||||||
FROM usage_logs
|
|
||||||
WHERE api_key_id = ANY($1) AND created_at >= $2
|
|
||||||
GROUP BY api_key_id
|
|
||||||
`
|
|
||||||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for rows.Next() {
|
|
||||||
var apiKeyID int64
|
|
||||||
var total float64
|
|
||||||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
|
||||||
_ = rows.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if stats, ok := result[apiKeyID]; ok {
|
|
||||||
stats.TodayActualCost = total
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := rows.Close(); err != nil {
|
if err := rows.Close(); err != nil {
|
||||||
@@ -1670,7 +1667,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||||
@@ -1753,7 +1751,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
|
|||||||
total_requests as requests,
|
total_requests as requests,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
output_tokens,
|
output_tokens,
|
||||||
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
|
cache_creation_tokens,
|
||||||
|
cache_read_tokens,
|
||||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||||
total_cost as cost,
|
total_cost as cost,
|
||||||
actual_cost
|
actual_cost
|
||||||
@@ -1768,7 +1767,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
|
|||||||
total_requests as requests,
|
total_requests as requests,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
output_tokens,
|
output_tokens,
|
||||||
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
|
cache_creation_tokens,
|
||||||
|
cache_read_tokens,
|
||||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||||
total_cost as cost,
|
total_cost as cost,
|
||||||
actual_cost
|
actual_cost
|
||||||
@@ -1812,6 +1812,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||||
|
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
%s
|
%s
|
||||||
@@ -2245,6 +2247,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
|
|||||||
return logs, paginationResultFromTotal(total, params), nil
|
return logs, paginationResultFromTotal(total, params), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
limit := params.Limit()
|
||||||
|
offset := params.Offset()
|
||||||
|
|
||||||
|
limitPos := len(args) + 1
|
||||||
|
offsetPos := len(args) + 2
|
||||||
|
listArgs := append(append([]any{}, args...), limit+1, offset)
|
||||||
|
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
|
||||||
|
|
||||||
|
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hasMore := false
|
||||||
|
if len(logs) > limit {
|
||||||
|
hasMore = true
|
||||||
|
logs = logs[:limit]
|
||||||
|
}
|
||||||
|
|
||||||
|
total := int64(offset) + int64(len(logs))
|
||||||
|
if hasMore {
|
||||||
|
// 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。
|
||||||
|
total = int64(offset) + int64(limit) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return logs, paginationResultFromTotal(total, params), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
|
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
|
||||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -2599,7 +2630,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
|
|||||||
&row.Requests,
|
&row.Requests,
|
||||||
&row.InputTokens,
|
&row.InputTokens,
|
||||||
&row.OutputTokens,
|
&row.OutputTokens,
|
||||||
&row.CacheTokens,
|
&row.CacheCreationTokens,
|
||||||
|
&row.CacheReadTokens,
|
||||||
&row.TotalTokens,
|
&row.TotalTokens,
|
||||||
&row.Cost,
|
&row.Cost,
|
||||||
&row.ActualCost,
|
&row.ActualCost,
|
||||||
@@ -2623,6 +2655,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
|
|||||||
&row.Requests,
|
&row.Requests,
|
||||||
&row.InputTokens,
|
&row.InputTokens,
|
||||||
&row.OutputTokens,
|
&row.OutputTokens,
|
||||||
|
&row.CacheCreationTokens,
|
||||||
|
&row.CacheReadTokens,
|
||||||
&row.TotalTokens,
|
&row.TotalTokens,
|
||||||
&row.Cost,
|
&row.Cost,
|
||||||
&row.ActualCost,
|
&row.ActualCost,
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
|||||||
filters := usagestats.UsageLogFilters{
|
filters := usagestats.UsageLogFilters{
|
||||||
RequestType: &requestType,
|
RequestType: &requestType,
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
|
ExactTotal: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||||
@@ -124,7 +125,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin
|
|||||||
|
|
||||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
|
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
|
||||||
WithArgs(start, end, requestType).
|
WithArgs(start, end, requestType).
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
|
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||||
|
|
||||||
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
|
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -143,7 +144,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
|
|||||||
|
|
||||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||||
WithArgs(start, end, requestType).
|
WithArgs(start, end, requestType).
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
|
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||||
|
|
||||||
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
|
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -243,21 +243,24 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
|||||||
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Batch load active subscriptions with groups to avoid N+1.
|
shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions
|
||||||
subs, err := r.client.UserSubscription.Query().
|
if shouldLoadSubscriptions {
|
||||||
Where(
|
// Batch load active subscriptions with groups to avoid N+1.
|
||||||
usersubscription.UserIDIn(userIDs...),
|
subs, err := r.client.UserSubscription.Query().
|
||||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
Where(
|
||||||
).
|
usersubscription.UserIDIn(userIDs...),
|
||||||
WithGroup().
|
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||||
All(ctx)
|
).
|
||||||
if err != nil {
|
WithGroup().
|
||||||
return nil, nil, err
|
All(ctx)
|
||||||
}
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
for i := range subs {
|
for i := range subs {
|
||||||
if u, ok := userMap[subs[i].UserID]; ok {
|
if u, ok := userMap[subs[i].UserID]; ok {
|
||||||
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
|
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,9 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAPIKeyRepository,
|
NewAPIKeyRepository,
|
||||||
NewGroupRepository,
|
NewGroupRepository,
|
||||||
NewAccountRepository,
|
NewAccountRepository,
|
||||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||||
|
NewScheduledTestPlanRepository, // 定时测试计划仓储
|
||||||
|
NewScheduledTestResultRepository, // 定时测试结果仓储
|
||||||
NewProxyRepository,
|
NewProxyRepository,
|
||||||
NewRedeemCodeRepository,
|
NewRedeemCodeRepository,
|
||||||
NewPromoCodeRepository,
|
NewPromoCodeRepository,
|
||||||
|
|||||||
@@ -446,9 +446,10 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
setup: func(t *testing.T, deps *contractDeps) {
|
setup: func(t *testing.T, deps *contractDeps) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
deps.settingRepo.SetAll(map[string]string{
|
deps.settingRepo.SetAll(map[string]string{
|
||||||
service.SettingKeyRegistrationEnabled: "true",
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
service.SettingKeyEmailVerifyEnabled: "false",
|
service.SettingKeyEmailVerifyEnabled: "false",
|
||||||
service.SettingKeyPromoCodeEnabled: "true",
|
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||||
|
service.SettingKeyPromoCodeEnabled: "true",
|
||||||
|
|
||||||
service.SettingKeySMTPHost: "smtp.example.com",
|
service.SettingKeySMTPHost: "smtp.example.com",
|
||||||
service.SettingKeySMTPPort: "587",
|
service.SettingKeySMTPPort: "587",
|
||||||
@@ -487,6 +488,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"data": {
|
"data": {
|
||||||
"registration_enabled": true,
|
"registration_enabled": true,
|
||||||
"email_verify_enabled": false,
|
"email_verify_enabled": false,
|
||||||
|
"registration_email_suffix_whitelist": [],
|
||||||
"promo_code_enabled": true,
|
"promo_code_enabled": true,
|
||||||
"password_reset_enabled": false,
|
"password_reset_enabled": false,
|
||||||
"totp_enabled": false,
|
"totp_enabled": false,
|
||||||
@@ -1411,7 +1413,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||||
ids := make([]int64, 0, len(r.byID))
|
ids := make([]int64, 0, len(r.byID))
|
||||||
for id := range r.byID {
|
for id := range r.byID {
|
||||||
if r.byID[id].UserID == userID {
|
if r.byID[id].UserID == userID {
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
|||||||
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||||
|
|||||||
@@ -537,7 +537,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
|||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -78,6 +78,9 @@ func RegisterAdminRoutes(
|
|||||||
|
|
||||||
// API Key 管理
|
// API Key 管理
|
||||||
registerAdminAPIKeyRoutes(admin, h)
|
registerAdminAPIKeyRoutes(admin, h)
|
||||||
|
|
||||||
|
// 定时测试计划
|
||||||
|
registerScheduledTestRoutes(admin, h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,6 +171,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
|
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
|
||||||
|
|
||||||
// Dashboard (vNext - raw path for MVP)
|
// Dashboard (vNext - raw path for MVP)
|
||||||
|
ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2)
|
||||||
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
|
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
|
||||||
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
|
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
|
||||||
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
|
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
|
||||||
@@ -180,6 +184,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
dashboard := admin.Group("/dashboard")
|
dashboard := admin.Group("/dashboard")
|
||||||
{
|
{
|
||||||
|
dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2)
|
||||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||||
@@ -476,6 +481,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
|
plans := admin.Group("/scheduled-test-plans")
|
||||||
|
{
|
||||||
|
plans.POST("", h.Admin.ScheduledTest.Create)
|
||||||
|
plans.PUT("/:id", h.Admin.ScheduledTest.Update)
|
||||||
|
plans.DELETE("/:id", h.Admin.ScheduledTest.Delete)
|
||||||
|
plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults)
|
||||||
|
}
|
||||||
|
// Nested under accounts
|
||||||
|
admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount)
|
||||||
|
}
|
||||||
|
|
||||||
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
rules := admin.Group("/error-passthrough-rules")
|
rules := admin.Group("/error-passthrough-rules")
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -853,15 +853,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OpenAIWSIngressModeOff = "off"
|
OpenAIWSIngressModeOff = "off"
|
||||||
OpenAIWSIngressModeShared = "shared"
|
OpenAIWSIngressModeShared = "shared"
|
||||||
OpenAIWSIngressModeDedicated = "dedicated"
|
OpenAIWSIngressModeDedicated = "dedicated"
|
||||||
|
OpenAIWSIngressModeCtxPool = "ctx_pool"
|
||||||
|
OpenAIWSIngressModePassthrough = "passthrough"
|
||||||
)
|
)
|
||||||
|
|
||||||
func normalizeOpenAIWSIngressMode(mode string) string {
|
func normalizeOpenAIWSIngressMode(mode string) string {
|
||||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||||
case OpenAIWSIngressModeOff:
|
case OpenAIWSIngressModeOff:
|
||||||
return OpenAIWSIngressModeOff
|
return OpenAIWSIngressModeOff
|
||||||
|
case OpenAIWSIngressModeCtxPool:
|
||||||
|
return OpenAIWSIngressModeCtxPool
|
||||||
|
case OpenAIWSIngressModePassthrough:
|
||||||
|
return OpenAIWSIngressModePassthrough
|
||||||
case OpenAIWSIngressModeShared:
|
case OpenAIWSIngressModeShared:
|
||||||
return OpenAIWSIngressModeShared
|
return OpenAIWSIngressModeShared
|
||||||
case OpenAIWSIngressModeDedicated:
|
case OpenAIWSIngressModeDedicated:
|
||||||
@@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
|
|||||||
|
|
||||||
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
||||||
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
||||||
|
if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
|
||||||
|
return OpenAIWSIngressModeCtxPool
|
||||||
|
}
|
||||||
return normalized
|
return normalized
|
||||||
}
|
}
|
||||||
return OpenAIWSIngressModeShared
|
return OpenAIWSIngressModeCtxPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
|
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。
|
||||||
//
|
//
|
||||||
// 优先级:
|
// 优先级:
|
||||||
// 1. 分类型 mode 新字段(string)
|
// 1. 分类型 mode 新字段(string)
|
||||||
// 2. 分类型 enabled 旧字段(bool)
|
// 2. 分类型 enabled 旧字段(bool)
|
||||||
// 3. 兼容 enabled 旧字段(bool)
|
// 3. 兼容 enabled 旧字段(bool)
|
||||||
// 4. defaultMode(非法时回退 shared)
|
// 4. defaultMode(非法时回退 ctx_pool)
|
||||||
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
|
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
|
||||||
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
|
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
|
||||||
if a == nil || !a.IsOpenAI() {
|
if a == nil || !a.IsOpenAI() {
|
||||||
@@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
|||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
if enabled {
|
if enabled {
|
||||||
return OpenAIWSIngressModeShared, true
|
return OpenAIWSIngressModeCtxPool, true
|
||||||
}
|
}
|
||||||
return OpenAIWSIngressModeOff, true
|
return OpenAIWSIngressModeOff, true
|
||||||
}
|
}
|
||||||
@@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
|||||||
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
|
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
|
||||||
return mode
|
return mode
|
||||||
}
|
}
|
||||||
|
// 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。
|
||||||
|
if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
|
||||||
|
return OpenAIWSIngressModeCtxPool
|
||||||
|
}
|
||||||
return resolvedDefault
|
return resolvedDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||||
t.Run("default fallback to shared", func(t *testing.T) {
|
t.Run("default fallback to ctx_pool", func(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Extra: map[string]any{},
|
Extra: map[string]any{},
|
||||||
}
|
}
|
||||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("oauth mode field has highest priority", func(t *testing.T) {
|
t.Run("oauth mode field has highest priority", func(t *testing.T) {
|
||||||
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||||
"openai_oauth_responses_websockets_v2_enabled": false,
|
"openai_oauth_responses_websockets_v2_enabled": false,
|
||||||
"responses_websockets_v2_enabled": false,
|
"responses_websockets_v2_enabled": false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("legacy enabled maps to shared", func(t *testing.T) {
|
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeAPIKey,
|
Type: AccountTypeAPIKey,
|
||||||
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
|||||||
"responses_websockets_v2_enabled": true,
|
"responses_websockets_v2_enabled": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
|
||||||
|
shared := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dedicated := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||||
|
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||||
|
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
|
||||||
|
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("legacy disabled maps to off", func(t *testing.T) {
|
t.Run("legacy disabled maps to off", func(t *testing.T) {
|
||||||
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
|||||||
"responses_websockets_v2_enabled": true,
|
"responses_websockets_v2_enabled": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("non openai always off", func(t *testing.T) {
|
t.Run("non openai always off", func(t *testing.T) {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -33,7 +34,7 @@ import (
|
|||||||
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
||||||
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
||||||
@@ -238,7 +239,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||||
}
|
}
|
||||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
|
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages?beta=true"
|
||||||
} else {
|
} else {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||||
}
|
}
|
||||||
@@ -1560,3 +1561,62 @@ func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) er
|
|||||||
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
||||||
return fmt.Errorf("%s", errorMsg)
|
return fmt.Errorf("%s", errorMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RunTestBackground executes an account test in-memory (no real HTTP client),
|
||||||
|
// capturing SSE output via httptest.NewRecorder, then parses the result.
|
||||||
|
func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID int64, modelID string) (*ScheduledTestResult, error) {
|
||||||
|
startedAt := time.Now()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(w)
|
||||||
|
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
||||||
|
|
||||||
|
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
|
||||||
|
|
||||||
|
finishedAt := time.Now()
|
||||||
|
body := w.Body.String()
|
||||||
|
responseText, errMsg := parseTestSSEOutput(body)
|
||||||
|
|
||||||
|
status := "success"
|
||||||
|
if testErr != nil || errMsg != "" {
|
||||||
|
status = "failed"
|
||||||
|
if errMsg == "" && testErr != nil {
|
||||||
|
errMsg = testErr.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ScheduledTestResult{
|
||||||
|
Status: status,
|
||||||
|
ResponseText: responseText,
|
||||||
|
ErrorMessage: errMsg,
|
||||||
|
LatencyMs: finishedAt.Sub(startedAt).Milliseconds(),
|
||||||
|
StartedAt: startedAt,
|
||||||
|
FinishedAt: finishedAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseTestSSEOutput extracts response text and error message from captured SSE output.
|
||||||
|
func parseTestSSEOutput(body string) (responseText, errMsg string) {
|
||||||
|
var texts []string
|
||||||
|
for _, line := range strings.Split(body, "\n") {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||||
|
var event TestEvent
|
||||||
|
if err := json.Unmarshal([]byte(jsonStr), &event); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch event.Type {
|
||||||
|
case "content":
|
||||||
|
if event.Text != "" {
|
||||||
|
texts = append(texts, event.Text)
|
||||||
|
}
|
||||||
|
case "error":
|
||||||
|
errMsg = event.Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responseText = strings.Join(texts, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -745,7 +745,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
|
|
||||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
|
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
|
||||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string)
|
|||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||||
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||||
|
|||||||
@@ -97,3 +97,10 @@ func (k *APIKey) GetDaysUntilExpiry() int {
|
|||||||
}
|
}
|
||||||
return int(duration.Hours() / 24)
|
return int(duration.Hours() / 24)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APIKeyListFilters holds optional filtering parameters for listing API keys.
|
||||||
|
type APIKeyListFilters struct {
|
||||||
|
Search string
|
||||||
|
Status string
|
||||||
|
GroupID *int64 // nil=不筛选, 0=无分组, >0=指定分组
|
||||||
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ type APIKeyRepository interface {
|
|||||||
Update(ctx context.Context, key *APIKey) error
|
Update(ctx context.Context, key *APIKey) error
|
||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error)
|
||||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||||
@@ -392,8 +392,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
|||||||
}
|
}
|
||||||
|
|
||||||
// List 获取用户的API Key列表
|
// List 获取用户的API Key列表
|
||||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
|
|||||||
panic("unexpected Delete call")
|
panic("unexpected Delete call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected ListByUserID call")
|
panic("unexpected ListByUserID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
|
|||||||
|
|
||||||
// 以下是接口要求实现但本测试不关心的方法
|
// 以下是接口要求实现但本测试不关心的方法
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected ListByUserID call")
|
panic("unexpected ListByUserID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -33,6 +34,7 @@ var (
|
|||||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||||
|
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
||||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||||
@@ -115,6 +117,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
|||||||
if isReservedEmail(email) {
|
if isReservedEmail(email) {
|
||||||
return "", nil, ErrEmailReserved
|
return "", nil, ErrEmailReserved
|
||||||
}
|
}
|
||||||
|
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否需要邀请码
|
// 检查是否需要邀请码
|
||||||
var invitationRedeemCode *RedeemCode
|
var invitationRedeemCode *RedeemCode
|
||||||
@@ -241,6 +246,9 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
|||||||
if isReservedEmail(email) {
|
if isReservedEmail(email) {
|
||||||
return ErrEmailReserved
|
return ErrEmailReserved
|
||||||
}
|
}
|
||||||
|
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// 检查邮箱是否已存在
|
// 检查邮箱是否已存在
|
||||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||||
@@ -279,6 +287,9 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
|||||||
if isReservedEmail(email) {
|
if isReservedEmail(email) {
|
||||||
return nil, ErrEmailReserved
|
return nil, ErrEmailReserved
|
||||||
}
|
}
|
||||||
|
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// 检查邮箱是否已存在
|
// 检查邮箱是否已存在
|
||||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||||
@@ -624,6 +635,32 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
|
||||||
|
if s.settingService == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
whitelist := s.settingService.GetRegistrationEmailSuffixWhitelist(ctx)
|
||||||
|
if !IsRegistrationEmailSuffixAllowed(email, whitelist) {
|
||||||
|
return buildEmailSuffixNotAllowedError(whitelist)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildEmailSuffixNotAllowedError(whitelist []string) error {
|
||||||
|
if len(whitelist) == 0 {
|
||||||
|
return ErrEmailSuffixNotAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := strings.Join(whitelist, ", ")
|
||||||
|
return infraerrors.BadRequest(
|
||||||
|
"EMAIL_SUFFIX_NOT_ALLOWED",
|
||||||
|
fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", allowed),
|
||||||
|
).WithMetadata(map[string]string{
|
||||||
|
"allowed_suffixes": strings.Join(whitelist, ","),
|
||||||
|
"allowed_suffix_count": strconv.Itoa(len(whitelist)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateToken 验证JWT token并返回用户声明
|
// ValidateToken 验证JWT token并返回用户声明
|
||||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -231,6 +232,51 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, ErrEmailReserved)
|
require.ErrorIs(t, err, ErrEmailReserved)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
|
||||||
|
repo := &userRepoStub{}
|
||||||
|
service := newAuthService(repo, map[string]string{
|
||||||
|
SettingKeyRegistrationEnabled: "true",
|
||||||
|
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
_, _, err := service.Register(context.Background(), "user@other.com", "password")
|
||||||
|
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||||
|
appErr := infraerrors.FromError(err)
|
||||||
|
require.Contains(t, appErr.Message, "@example.com")
|
||||||
|
require.Contains(t, appErr.Message, "@company.com")
|
||||||
|
require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason)
|
||||||
|
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
|
||||||
|
require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
|
||||||
|
repo := &userRepoStub{nextID: 8}
|
||||||
|
service := newAuthService(repo, map[string]string{
|
||||||
|
SettingKeyRegistrationEnabled: "true",
|
||||||
|
SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
_, user, err := service.Register(context.Background(), "user@example.com", "password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.Equal(t, int64(8), user.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
|
||||||
|
repo := &userRepoStub{}
|
||||||
|
service := newAuthService(repo, map[string]string{
|
||||||
|
SettingKeyRegistrationEnabled: "true",
|
||||||
|
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
err := service.SendVerifyCode(context.Background(), "user@other.com")
|
||||||
|
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||||
|
appErr := infraerrors.FromError(err)
|
||||||
|
require.Contains(t, appErr.Message, "@example.com")
|
||||||
|
require.Contains(t, appErr.Message, "@company.com")
|
||||||
|
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthService_Register_CreateError(t *testing.T) {
|
func TestAuthService_Register_CreateError(t *testing.T) {
|
||||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||||
service := newAuthService(repo, map[string]string{
|
service := newAuthService(repo, map[string]string{
|
||||||
@@ -402,7 +448,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
|
|||||||
repo := &userRepoStub{nextID: 42}
|
repo := &userRepoStub{nextID: 42}
|
||||||
assigner := &defaultSubscriptionAssignerStub{}
|
assigner := &defaultSubscriptionAssignerStub{}
|
||||||
service := newAuthService(repo, map[string]string{
|
service := newAuthService(repo, map[string]string{
|
||||||
SettingKeyRegistrationEnabled: "true",
|
SettingKeyRegistrationEnabled: "true",
|
||||||
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
|
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
|
||||||
}, nil)
|
}, nil)
|
||||||
service.defaultSubAssigner = assigner
|
service.defaultSubAssigner = assigner
|
||||||
|
|||||||
@@ -74,11 +74,12 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
|||||||
// Setting keys
|
// Setting keys
|
||||||
const (
|
const (
|
||||||
// 注册设置
|
// 注册设置
|
||||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组)
|
||||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||||
|
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||||
|
|||||||
@@ -88,6 +88,49 @@ func TestCheckErrorPolicy(t *testing.T) {
|
|||||||
body: []byte(`overloaded service`),
|
body: []byte(`overloaded service`),
|
||||||
expected: ErrorPolicyTempUnscheduled,
|
expected: ErrorPolicyTempUnscheduled,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled",
|
||||||
|
account: &Account{
|
||||||
|
ID: 14,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(401),
|
||||||
|
"keywords": []any{"unauthorized"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 401,
|
||||||
|
body: []byte(`unauthorized`),
|
||||||
|
expected: ErrorPolicyTempUnscheduled,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
|
||||||
|
account: &Account{
|
||||||
|
ID: 15,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(401),
|
||||||
|
"keywords": []any{"unauthorized"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 401,
|
||||||
|
body: []byte(`unauthorized`),
|
||||||
|
expected: ErrorPolicyNone,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "temp_unschedulable_body_miss_returns_none",
|
name: "temp_unschedulable_body_miss_returns_none",
|
||||||
account: &Account{
|
account: &Account{
|
||||||
|
|||||||
@@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.Stream)
|
require.True(t, result.Stream)
|
||||||
|
|
||||||
require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体")
|
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射")
|
||||||
require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String())
|
|
||||||
|
|
||||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||||
@@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
|||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
bodyBytes, ok := rawBody.([]byte)
|
bodyBytes, ok := rawBody.([]byte)
|
||||||
require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
|
require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
|
||||||
require.Equal(t, body, bodyBytes)
|
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
|
||||||
@@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
|||||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体")
|
require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射")
|
||||||
require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String())
|
|
||||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||||
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
|
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
|
||||||
@@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
|||||||
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
modelMapping map[string]any // nil = 不配置映射
|
||||||
|
expectedModel string
|
||||||
|
endpoint string // "messages" or "count_tokens"
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Forward: 无映射配置时不改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: nil,
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Forward: 空映射配置时不改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{},
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Forward: 模型不在映射表中时不改写",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Forward: 精确匹配映射应改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||||
|
expectedModel: "claude-sonnet-4-5-20241022",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Forward: 通配符映射应改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||||
|
expectedModel: "claude-sonnet-4-5-20241022",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CountTokens: 无映射配置时不改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: nil,
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "count_tokens",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CountTokens: 模型不在映射表中时不改写",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "count_tokens",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CountTokens: 精确匹配映射应改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||||
|
expectedModel: "claude-sonnet-4-5-20241022",
|
||||||
|
endpoint: "count_tokens",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CountTokens: 通配符映射应改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||||
|
expectedModel: "claude-sonnet-4-5-20241022",
|
||||||
|
endpoint: "count_tokens",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Body: body,
|
||||||
|
Model: tt.model,
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := map[string]any{
|
||||||
|
"api_key": "upstream-key",
|
||||||
|
"base_url": "https://api.anthropic.com",
|
||||||
|
}
|
||||||
|
if tt.modelMapping != nil {
|
||||||
|
credentials["model_mapping"] = tt.modelMapping
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 300,
|
||||||
|
Name: "edge-case-test",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: credentials,
|
||||||
|
Extra: map[string]any{"anthropic_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.endpoint == "messages" {
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
parsed.Stream = false
|
||||||
|
|
||||||
|
upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}`
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamJSON)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||||
|
"Forward 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||||
|
} else {
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
|
||||||
|
upstreamRespBody := `{"input_tokens":42}`
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||||
|
"CountTokens 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields
|
||||||
|
// 确保模型映射只替换 model 字段,不影响请求体中的其他字段
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
|
||||||
|
// 包含复杂字段的请求体:system、thinking、messages
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`)
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Body: body,
|
||||||
|
Model: "claude-sonnet-4-20250514",
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamRespBody := `{"input_tokens":42}`
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 301,
|
||||||
|
Name: "preserve-fields-test",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "upstream-key",
|
||||||
|
"base_url": "https://api.anthropic.com",
|
||||||
|
"model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{"anthropic_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sentBody := upstream.lastBody
|
||||||
|
require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射")
|
||||||
|
require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改")
|
||||||
|
require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改")
|
||||||
|
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改")
|
||||||
|
require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改")
|
||||||
|
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
|
||||||
|
// 确保空模型名不会触发映射逻辑
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Body: body,
|
||||||
|
Model: "", // 空模型
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamRespBody := `{"input_tokens":10}`
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 302,
|
||||||
|
Name: "empty-model-test",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "upstream-key",
|
||||||
|
"base_url": "https://api.anthropic.com",
|
||||||
|
"model_mapping": map[string]any{"*": "claude-3-opus-20240229"},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{"anthropic_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// 空模型名时,body 应原样透传,不应触发映射
|
||||||
|
require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -3889,7 +3889,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
}
|
}
|
||||||
|
|
||||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime)
|
passthroughBody := parsed.Body
|
||||||
|
passthroughModel := parsed.Model
|
||||||
|
if passthroughModel != "" {
|
||||||
|
if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel {
|
||||||
|
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||||
|
logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||||
|
passthroughModel = mappedModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
body := parsed.Body
|
body := parsed.Body
|
||||||
@@ -4574,7 +4583,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
targetURL = validatedURL + "/v1/messages"
|
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||||
@@ -4954,7 +4963,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
targetURL = validatedURL + "/v1/messages"
|
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -6781,7 +6790,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body)
|
passthroughBody := parsed.Body
|
||||||
|
if reqModel := parsed.Model; reqModel != "" {
|
||||||
|
if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel {
|
||||||
|
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||||
|
logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
body := parsed.Body
|
body := parsed.Body
|
||||||
@@ -7072,7 +7088,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||||
@@ -7119,7 +7135,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -122,6 +122,28 @@ func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
|
|||||||
body: []byte(`overloaded service`),
|
body: []byte(`overloaded service`),
|
||||||
expected: ErrorPolicyTempUnscheduled,
|
expected: ErrorPolicyTempUnscheduled,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "gemini_apikey_temp_unschedulable_401_second_hit_returns_none",
|
||||||
|
account: &Account{
|
||||||
|
ID: 105,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(401),
|
||||||
|
"keywords": []any{"unauthorized"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 401,
|
||||||
|
body: []byte(`unauthorized`),
|
||||||
|
expected: ErrorPolicyNone,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "gemini_custom_codes_override_temp_unschedulable",
|
name: "gemini_custom_codes_override_temp_unschedulable",
|
||||||
account: &Account{
|
account: &Account{
|
||||||
|
|||||||
@@ -19,8 +19,10 @@ import (
|
|||||||
|
|
||||||
// 预编译正则表达式(避免每次调用重新编译)
|
// 预编译正则表达式(避免每次调用重新编译)
|
||||||
var (
|
var (
|
||||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
// 匹配 user_id 格式:
|
||||||
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
|
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
|
||||||
|
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
|
||||||
|
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
|
||||||
// 匹配 User-Agent 版本号: xxx/x.y.z
|
// 匹配 User-Agent 版本号: xxx/x.y.z
|
||||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||||
)
|
)
|
||||||
@@ -239,13 +241,16 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 匹配格式: user_{64位hex}_account__session_{uuid}
|
// 匹配格式:
|
||||||
|
// 旧格式: user_{64位hex}_account__session_{uuid}
|
||||||
|
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
|
||||||
matches := userIDRegex.FindStringSubmatch(userID)
|
matches := userIDRegex.FindStringSubmatch(userID)
|
||||||
if matches == nil {
|
if matches == nil {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionTail := matches[1] // 原始session UUID
|
// matches[1] = account UUID (可能为空), matches[2] = session UUID
|
||||||
|
sessionTail := matches[2] // 原始session UUID
|
||||||
|
|
||||||
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
||||||
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
||||||
|
|||||||
@@ -263,13 +263,15 @@ type OpenAIGatewayService struct {
|
|||||||
toolCorrector *CodexToolCorrector
|
toolCorrector *CodexToolCorrector
|
||||||
openaiWSResolver OpenAIWSProtocolResolver
|
openaiWSResolver OpenAIWSProtocolResolver
|
||||||
|
|
||||||
openaiWSPoolOnce sync.Once
|
openaiWSPoolOnce sync.Once
|
||||||
openaiWSStateStoreOnce sync.Once
|
openaiWSStateStoreOnce sync.Once
|
||||||
openaiSchedulerOnce sync.Once
|
openaiSchedulerOnce sync.Once
|
||||||
openaiWSPool *openAIWSConnPool
|
openaiWSPassthroughDialerOnce sync.Once
|
||||||
openaiWSStateStore OpenAIWSStateStore
|
openaiWSPool *openAIWSConnPool
|
||||||
openaiScheduler OpenAIAccountScheduler
|
openaiWSStateStore OpenAIWSStateStore
|
||||||
openaiAccountStats *openAIAccountRuntimeStats
|
openaiScheduler OpenAIAccountScheduler
|
||||||
|
openaiWSPassthroughDialer openAIWSClientDialer
|
||||||
|
openaiAccountStats *openAIAccountRuntimeStats
|
||||||
|
|
||||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
|
||||||
coderws "github.com/coder/websocket"
|
coderws "github.com/coder/websocket"
|
||||||
"github.com/coder/websocket/wsjson"
|
"github.com/coder/websocket/wsjson"
|
||||||
)
|
)
|
||||||
@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
|
|||||||
conn *coderws.Conn
|
conn *coderws.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil)
|
||||||
|
|
||||||
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
|
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
|
||||||
if c == nil || c.conn == nil {
|
if c == nil || c.conn == nil {
|
||||||
return errOpenAIWSConnClosed
|
return errOpenAIWSConnClosed
|
||||||
@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
msgType, payload, err := c.conn.Read(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return coderws.MessageText, nil, err
|
||||||
|
}
|
||||||
|
return msgType, payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return errOpenAIWSConnClosed
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return c.conn.Write(ctx, msgType, payload)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
|
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
|
||||||
if c == nil || c.conn == nil {
|
if c == nil || c.conn == nil {
|
||||||
return errOpenAIWSConnClosed
|
return errOpenAIWSConnClosed
|
||||||
|
|||||||
@@ -46,9 +46,10 @@ const (
|
|||||||
openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
|
openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
|
||||||
openAIWSPayloadSizeEstimateMaxItems = 16
|
openAIWSPayloadSizeEstimateMaxItems = 16
|
||||||
|
|
||||||
openAIWSEventFlushBatchSizeDefault = 4
|
openAIWSEventFlushBatchSizeDefault = 4
|
||||||
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
|
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
|
||||||
openAIWSPayloadLogSampleDefault = 0.2
|
openAIWSPayloadLogSampleDefault = 0.2
|
||||||
|
openAIWSPassthroughIdleTimeoutDefault = time.Hour
|
||||||
|
|
||||||
openAIWSStoreDisabledConnModeStrict = "strict"
|
openAIWSStoreDisabledConnModeStrict = "strict"
|
||||||
openAIWSStoreDisabledConnModeAdaptive = "adaptive"
|
openAIWSStoreDisabledConnModeAdaptive = "adaptive"
|
||||||
@@ -904,6 +905,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
|
|||||||
return s.openaiWSPool
|
return s.openaiWSPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.openaiWSPassthroughDialerOnce.Do(func() {
|
||||||
|
if s.openaiWSPassthroughDialer == nil {
|
||||||
|
s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return s.openaiWSPassthroughDialer
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
|
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
|
||||||
pool := s.getOpenAIWSConnPool()
|
pool := s.getOpenAIWSConnPool()
|
||||||
if pool == nil {
|
if pool == nil {
|
||||||
@@ -967,6 +980,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
|
|||||||
return 15 * time.Minute
|
return 15 * time.Minute
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration {
|
||||||
|
if timeout := s.openAIWSReadTimeout(); timeout > 0 {
|
||||||
|
return timeout
|
||||||
|
}
|
||||||
|
return openAIWSPassthroughIdleTimeoutDefault
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
|
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
|
||||||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
|
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
|
||||||
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
|
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
|
||||||
@@ -2322,7 +2342,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
|
|
||||||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||||||
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
|
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
|
||||||
ingressMode := OpenAIWSIngressModeShared
|
ingressMode := OpenAIWSIngressModeCtxPool
|
||||||
if modeRouterV2Enabled {
|
if modeRouterV2Enabled {
|
||||||
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
|
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
|
||||||
if ingressMode == OpenAIWSIngressModeOff {
|
if ingressMode == OpenAIWSIngressModeOff {
|
||||||
@@ -2332,6 +2352,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
switch ingressMode {
|
||||||
|
case OpenAIWSIngressModePassthrough:
|
||||||
|
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||||||
|
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
|
||||||
|
}
|
||||||
|
return s.proxyResponsesWebSocketV2Passthrough(
|
||||||
|
ctx,
|
||||||
|
c,
|
||||||
|
clientConn,
|
||||||
|
account,
|
||||||
|
token,
|
||||||
|
firstClientMessage,
|
||||||
|
hooks,
|
||||||
|
wsDecision,
|
||||||
|
)
|
||||||
|
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
|
||||||
|
// continue
|
||||||
|
default:
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"websocket mode only supports ctx_pool/passthrough",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||||||
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
|
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
|
|||||||
require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
|
require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
|
||||||
require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
|
require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
|
||||||
|
|
||||||
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
|
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case serverErr := <-serverErrCh:
|
case serverErr := <-serverErrCh:
|
||||||
@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
|
|||||||
require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
|
require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
|
||||||
|
upstreamConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPassthroughDialer: captureDialer,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 452,
|
||||||
|
Name: "openai-ingress-passthrough",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
resultCh := make(chan *OpenAIForwardResult, 1)
|
||||||
|
hooks := &OpenAIWSIngressHooks{
|
||||||
|
AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
|
||||||
|
if turnErr == nil && result != nil {
|
||||||
|
resultCh <- result
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := r.Clone(r.Context())
|
||||||
|
req.Header = req.Header.Clone()
|
||||||
|
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
if readErr != nil {
|
||||||
|
serverErrCh <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
|
||||||
|
}))
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, event, readErr := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||||
|
require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String())
|
||||||
|
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case serverErr := <-serverErrCh:
|
||||||
|
require.NoError(t, serverErr)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("等待 passthrough websocket 结束超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
require.Equal(t, "resp_passthrough_turn_1", result.RequestID)
|
||||||
|
require.True(t, result.OpenAIWSMode)
|
||||||
|
require.Equal(t, 2, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("未收到 passthrough turn 结果回调")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket")
|
||||||
|
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
|
|||||||
return event, nil
|
return event, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
payload, err := c.ReadMessage(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return coderws.MessageText, nil, err
|
||||||
|
}
|
||||||
|
return coderws.MessageText, payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error {
|
||||||
|
return c.WriteJSON(ctx, json.RawMessage(payload))
|
||||||
|
}
|
||||||
|
|
||||||
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
|
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
|
||||||
_ = ctx
|
_ = ctx
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
|
|||||||
switch mode {
|
switch mode {
|
||||||
case OpenAIWSIngressModeOff:
|
case OpenAIWSIngressModeOff:
|
||||||
return openAIWSHTTPDecision("account_mode_off")
|
return openAIWSHTTPDecision("account_mode_off")
|
||||||
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
|
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough:
|
||||||
// continue
|
// continue
|
||||||
|
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
|
||||||
|
// 历史值兼容:按 ctx_pool 处理。
|
||||||
|
mode = OpenAIWSIngressModeCtxPool
|
||||||
default:
|
default:
|
||||||
return openAIWSHTTPDecision("account_mode_off")
|
return openAIWSHTTPDecision("account_mode_off")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
|||||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||||
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
|
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||||
|
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Concurrency: 1,
|
Concurrency: 1,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
|
t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
|
||||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
|
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
|
||||||
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
||||||
require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
|
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("off mode routes to http", func(t *testing.T) {
|
t.Run("off mode routes to http", func(t *testing.T) {
|
||||||
@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
|||||||
require.Equal(t, "account_mode_off", decision.Reason)
|
require.Equal(t, "account_mode_off", decision.Reason)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
|
t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
|
||||||
legacyAccount := &Account{
|
legacyAccount := &Account{
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeAPIKey,
|
Type: AccountTypeAPIKey,
|
||||||
@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
|
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
|
||||||
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
||||||
require.Equal(t, "ws_v2_mode_shared", decision.Reason)
|
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("passthrough mode routes to ws v2", func(t *testing.T) {
|
||||||
|
passthroughAccount := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount)
|
||||||
|
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
||||||
|
require.Equal(t, "ws_v2_mode_passthrough", decision.Reason)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
|
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
|
||||||
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
|
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
|
||||||
|
|||||||
24
backend/internal/service/openai_ws_v2/caddy_adapter.go
Normal file
24
backend/internal/service/openai_ws_v2/caddy_adapter.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想:
|
||||||
|
// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。
|
||||||
|
//
|
||||||
|
// Reference:
|
||||||
|
// - Project: caddyserver/caddy (Apache-2.0)
|
||||||
|
// - Commit: f283062d37c50627d53ca682ebae2ce219b35515
|
||||||
|
// - Files:
|
||||||
|
// - modules/caddyhttp/reverseproxy/streaming.go
|
||||||
|
// - modules/caddyhttp/reverseproxy/reverseproxy.go
|
||||||
|
func runCaddyStyleRelay(
|
||||||
|
ctx context.Context,
|
||||||
|
clientConn FrameConn,
|
||||||
|
upstreamConn FrameConn,
|
||||||
|
firstClientMessage []byte,
|
||||||
|
options RelayOptions,
|
||||||
|
) (RelayResult, *RelayExit) {
|
||||||
|
return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
|
||||||
|
}
|
||||||
23
backend/internal/service/openai_ws_v2/entry.go
Normal file
23
backend/internal/service/openai_ws_v2/entry.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// EntryInput 是 passthrough v2 数据面的入口参数。
|
||||||
|
type EntryInput struct {
|
||||||
|
Ctx context.Context
|
||||||
|
ClientConn FrameConn
|
||||||
|
UpstreamConn FrameConn
|
||||||
|
FirstClientMessage []byte
|
||||||
|
Options RelayOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunEntry 是 openai_ws_v2 包对外的统一入口。
|
||||||
|
func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
|
||||||
|
return runCaddyStyleRelay(
|
||||||
|
input.Ctx,
|
||||||
|
input.ClientConn,
|
||||||
|
input.UpstreamConn,
|
||||||
|
input.FirstClientMessage,
|
||||||
|
input.Options,
|
||||||
|
)
|
||||||
|
}
|
||||||
29
backend/internal/service/openai_ws_v2/metrics.go
Normal file
29
backend/internal/service/openai_ws_v2/metrics.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。
|
||||||
|
type MetricsSnapshot struct {
|
||||||
|
SemanticMutationTotal int64 `json:"semantic_mutation_total"`
|
||||||
|
UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// passthrough 路径默认不会做语义改写,该计数通常应保持为 0(保留用于未来防御性校验)。
|
||||||
|
passthroughSemanticMutationTotal atomic.Int64
|
||||||
|
passthroughUsageParseFailureTotal atomic.Int64
|
||||||
|
)
|
||||||
|
|
||||||
|
func recordUsageParseFailure() {
|
||||||
|
passthroughUsageParseFailureTotal.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SnapshotMetrics 返回当前 passthrough 指标快照。
|
||||||
|
func SnapshotMetrics() MetricsSnapshot {
|
||||||
|
return MetricsSnapshot{
|
||||||
|
SemanticMutationTotal: passthroughSemanticMutationTotal.Load(),
|
||||||
|
UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(),
|
||||||
|
}
|
||||||
|
}
|
||||||
807
backend/internal/service/openai_ws_v2/passthrough_relay.go
Normal file
807
backend/internal/service/openai_ws_v2/passthrough_relay.go
Normal file
@@ -0,0 +1,807 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FrameConn interface {
|
||||||
|
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
|
||||||
|
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Usage struct {
|
||||||
|
InputTokens int
|
||||||
|
OutputTokens int
|
||||||
|
CacheCreationInputTokens int
|
||||||
|
CacheReadInputTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayResult struct {
|
||||||
|
RequestModel string
|
||||||
|
Usage Usage
|
||||||
|
RequestID string
|
||||||
|
TerminalEventType string
|
||||||
|
FirstTokenMs *int
|
||||||
|
Duration time.Duration
|
||||||
|
ClientToUpstreamFrames int64
|
||||||
|
UpstreamToClientFrames int64
|
||||||
|
DroppedDownstreamFrames int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayTurnResult struct {
|
||||||
|
RequestModel string
|
||||||
|
Usage Usage
|
||||||
|
RequestID string
|
||||||
|
TerminalEventType string
|
||||||
|
Duration time.Duration
|
||||||
|
FirstTokenMs *int
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayExit struct {
|
||||||
|
Stage string
|
||||||
|
Err error
|
||||||
|
WroteDownstream bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayOptions struct {
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
UpstreamDrainTimeout time.Duration
|
||||||
|
FirstMessageType coderws.MessageType
|
||||||
|
OnUsageParseFailure func(eventType string, usageRaw string)
|
||||||
|
OnTurnComplete func(turn RelayTurnResult)
|
||||||
|
OnTrace func(event RelayTraceEvent)
|
||||||
|
Now func() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayTraceEvent struct {
|
||||||
|
Stage string
|
||||||
|
Direction string
|
||||||
|
MessageType string
|
||||||
|
PayloadBytes int
|
||||||
|
Graceful bool
|
||||||
|
WroteDownstream bool
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type relayState struct {
|
||||||
|
usage Usage
|
||||||
|
requestModel string
|
||||||
|
lastResponseID string
|
||||||
|
terminalEventType string
|
||||||
|
firstTokenMs *int
|
||||||
|
turnTimingByID map[string]*relayTurnTiming
|
||||||
|
}
|
||||||
|
|
||||||
|
type relayExitSignal struct {
|
||||||
|
stage string
|
||||||
|
err error
|
||||||
|
graceful bool
|
||||||
|
wroteDownstream bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type observedUpstreamEvent struct {
|
||||||
|
terminal bool
|
||||||
|
eventType string
|
||||||
|
responseID string
|
||||||
|
usage Usage
|
||||||
|
duration time.Duration
|
||||||
|
firstToken *int
|
||||||
|
}
|
||||||
|
|
||||||
|
type relayTurnTiming struct {
|
||||||
|
startAt time.Time
|
||||||
|
firstTokenMs *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func Relay(
|
||||||
|
ctx context.Context,
|
||||||
|
clientConn FrameConn,
|
||||||
|
upstreamConn FrameConn,
|
||||||
|
firstClientMessage []byte,
|
||||||
|
options RelayOptions,
|
||||||
|
) (RelayResult, *RelayExit) {
|
||||||
|
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
|
||||||
|
if clientConn == nil || upstreamConn == nil {
|
||||||
|
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
nowFn := options.Now
|
||||||
|
if nowFn == nil {
|
||||||
|
nowFn = time.Now
|
||||||
|
}
|
||||||
|
writeTimeout := options.WriteTimeout
|
||||||
|
if writeTimeout <= 0 {
|
||||||
|
writeTimeout = 2 * time.Minute
|
||||||
|
}
|
||||||
|
drainTimeout := options.UpstreamDrainTimeout
|
||||||
|
if drainTimeout <= 0 {
|
||||||
|
drainTimeout = 1200 * time.Millisecond
|
||||||
|
}
|
||||||
|
firstMessageType := options.FirstMessageType
|
||||||
|
if firstMessageType != coderws.MessageBinary {
|
||||||
|
firstMessageType = coderws.MessageText
|
||||||
|
}
|
||||||
|
startAt := nowFn()
|
||||||
|
state := &relayState{requestModel: result.RequestModel}
|
||||||
|
onTrace := options.OnTrace
|
||||||
|
|
||||||
|
relayCtx, relayCancel := context.WithCancel(ctx)
|
||||||
|
defer relayCancel()
|
||||||
|
|
||||||
|
lastActivity := atomic.Int64{}
|
||||||
|
lastActivity.Store(nowFn().UnixNano())
|
||||||
|
markActivity := func() {
|
||||||
|
lastActivity.Store(nowFn().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
|
||||||
|
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
|
||||||
|
}
|
||||||
|
writeClient := func(msgType coderws.MessageType, payload []byte) error {
|
||||||
|
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return clientConn.WriteFrame(writeCtx, msgType, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientToUpstreamFrames := &atomic.Int64{}
|
||||||
|
upstreamToClientFrames := &atomic.Int64{}
|
||||||
|
droppedDownstreamFrames := &atomic.Int64{}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_start",
|
||||||
|
PayloadBytes: len(firstClientMessage),
|
||||||
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
||||||
|
result.Duration = nowFn().Sub(startAt)
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_first_message_failed",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
|
PayloadBytes: len(firstClientMessage),
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
||||||
|
}
|
||||||
|
clientToUpstreamFrames.Add(1)
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_first_message_ok",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
|
PayloadBytes: len(firstClientMessage),
|
||||||
|
})
|
||||||
|
markActivity()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 3)
|
||||||
|
dropDownstreamWrites := atomic.Bool{}
|
||||||
|
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
||||||
|
go runUpstreamToClient(
|
||||||
|
relayCtx,
|
||||||
|
upstreamConn,
|
||||||
|
writeClient,
|
||||||
|
startAt,
|
||||||
|
nowFn,
|
||||||
|
state,
|
||||||
|
options.OnUsageParseFailure,
|
||||||
|
options.OnTurnComplete,
|
||||||
|
&dropDownstreamWrites,
|
||||||
|
upstreamToClientFrames,
|
||||||
|
droppedDownstreamFrames,
|
||||||
|
markActivity,
|
||||||
|
onTrace,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
|
||||||
|
|
||||||
|
firstExit := <-exitCh
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "first_exit",
|
||||||
|
Direction: relayDirectionFromStage(firstExit.stage),
|
||||||
|
Graceful: firstExit.graceful,
|
||||||
|
WroteDownstream: firstExit.wroteDownstream,
|
||||||
|
Error: relayErrorString(firstExit.err),
|
||||||
|
})
|
||||||
|
combinedWroteDownstream := firstExit.wroteDownstream
|
||||||
|
secondExit := relayExitSignal{graceful: true}
|
||||||
|
hasSecondExit := false
|
||||||
|
|
||||||
|
// 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。
|
||||||
|
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||||
|
dropDownstreamWrites.Store(true)
|
||||||
|
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
|
||||||
|
} else {
|
||||||
|
relayCancel()
|
||||||
|
_ = upstreamConn.Close()
|
||||||
|
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
||||||
|
}
|
||||||
|
if hasSecondExit {
|
||||||
|
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "second_exit",
|
||||||
|
Direction: relayDirectionFromStage(secondExit.stage),
|
||||||
|
Graceful: secondExit.graceful,
|
||||||
|
WroteDownstream: secondExit.wroteDownstream,
|
||||||
|
Error: relayErrorString(secondExit.err),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
relayCancel()
|
||||||
|
_ = upstreamConn.Close()
|
||||||
|
|
||||||
|
enrichResult(&result, state, nowFn().Sub(startAt))
|
||||||
|
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
|
||||||
|
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
|
||||||
|
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
|
||||||
|
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||||
|
stage := "client_disconnected"
|
||||||
|
exitErr := firstExit.err
|
||||||
|
if hasSecondExit && !secondExit.graceful {
|
||||||
|
stage = secondExit.stage
|
||||||
|
exitErr = secondExit.err
|
||||||
|
}
|
||||||
|
if exitErr == nil {
|
||||||
|
exitErr = io.EOF
|
||||||
|
}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_exit",
|
||||||
|
Direction: relayDirectionFromStage(stage),
|
||||||
|
Graceful: false,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
Error: relayErrorString(exitErr),
|
||||||
|
})
|
||||||
|
return result, &RelayExit{
|
||||||
|
Stage: stage,
|
||||||
|
Err: exitErr,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_complete",
|
||||||
|
Graceful: true,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
})
|
||||||
|
_ = clientConn.Close()
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
if !firstExit.graceful {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_exit",
|
||||||
|
Direction: relayDirectionFromStage(firstExit.stage),
|
||||||
|
Graceful: false,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
Error: relayErrorString(firstExit.err),
|
||||||
|
})
|
||||||
|
return result, &RelayExit{
|
||||||
|
Stage: firstExit.stage,
|
||||||
|
Err: firstExit.err,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasSecondExit && !secondExit.graceful {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_exit",
|
||||||
|
Direction: relayDirectionFromStage(secondExit.stage),
|
||||||
|
Graceful: false,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
Error: relayErrorString(secondExit.err),
|
||||||
|
})
|
||||||
|
return result, &RelayExit{
|
||||||
|
Stage: secondExit.stage,
|
||||||
|
Err: secondExit.err,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_complete",
|
||||||
|
Graceful: true,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
})
|
||||||
|
_ = clientConn.Close()
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runClientToUpstream(
|
||||||
|
ctx context.Context,
|
||||||
|
clientConn FrameConn,
|
||||||
|
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
|
||||||
|
markActivity func(),
|
||||||
|
forwardedFrames *atomic.Int64,
|
||||||
|
onTrace func(event RelayTraceEvent),
|
||||||
|
exitCh chan<- relayExitSignal,
|
||||||
|
) {
|
||||||
|
for {
|
||||||
|
msgType, payload, err := clientConn.ReadFrame(ctx)
|
||||||
|
if err != nil {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "read_client_failed",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
Error: err.Error(),
|
||||||
|
Graceful: isDisconnectError(err),
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
if err := writeUpstream(msgType, payload); err != nil {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_upstream_failed",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
MessageType: relayMessageTypeString(msgType),
|
||||||
|
PayloadBytes: len(payload),
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{stage: "write_upstream", err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if forwardedFrames != nil {
|
||||||
|
forwardedFrames.Add(1)
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runUpstreamToClient(
|
||||||
|
ctx context.Context,
|
||||||
|
upstreamConn FrameConn,
|
||||||
|
writeClient func(msgType coderws.MessageType, payload []byte) error,
|
||||||
|
startAt time.Time,
|
||||||
|
nowFn func() time.Time,
|
||||||
|
state *relayState,
|
||||||
|
onUsageParseFailure func(eventType string, usageRaw string),
|
||||||
|
onTurnComplete func(turn RelayTurnResult),
|
||||||
|
dropDownstreamWrites *atomic.Bool,
|
||||||
|
forwardedFrames *atomic.Int64,
|
||||||
|
droppedFrames *atomic.Int64,
|
||||||
|
markActivity func(),
|
||||||
|
onTrace func(event RelayTraceEvent),
|
||||||
|
exitCh chan<- relayExitSignal,
|
||||||
|
) {
|
||||||
|
wroteDownstream := false
|
||||||
|
for {
|
||||||
|
msgType, payload, err := upstreamConn.ReadFrame(ctx)
|
||||||
|
if err != nil {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "read_upstream_failed",
|
||||||
|
Direction: "upstream_to_client",
|
||||||
|
Error: err.Error(),
|
||||||
|
Graceful: isDisconnectError(err),
|
||||||
|
WroteDownstream: wroteDownstream,
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{
|
||||||
|
stage: "read_upstream",
|
||||||
|
err: err,
|
||||||
|
graceful: isDisconnectError(err),
|
||||||
|
wroteDownstream: wroteDownstream,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
observedEvent := observedUpstreamEvent{}
|
||||||
|
switch msgType {
|
||||||
|
case coderws.MessageText:
|
||||||
|
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure)
|
||||||
|
case coderws.MessageBinary:
|
||||||
|
// binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。
|
||||||
|
}
|
||||||
|
emitTurnComplete(onTurnComplete, state, observedEvent)
|
||||||
|
if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
|
||||||
|
if droppedFrames != nil {
|
||||||
|
droppedFrames.Add(1)
|
||||||
|
}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "drop_downstream_frame",
|
||||||
|
Direction: "upstream_to_client",
|
||||||
|
MessageType: relayMessageTypeString(msgType),
|
||||||
|
PayloadBytes: len(payload),
|
||||||
|
WroteDownstream: wroteDownstream,
|
||||||
|
})
|
||||||
|
if observedEvent.terminal {
|
||||||
|
exitCh <- relayExitSignal{
|
||||||
|
stage: "drain_terminal",
|
||||||
|
graceful: true,
|
||||||
|
wroteDownstream: wroteDownstream,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := writeClient(msgType, payload); err != nil {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_client_failed",
|
||||||
|
Direction: "upstream_to_client",
|
||||||
|
MessageType: relayMessageTypeString(msgType),
|
||||||
|
PayloadBytes: len(payload),
|
||||||
|
WroteDownstream: wroteDownstream,
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wroteDownstream = true
|
||||||
|
if forwardedFrames != nil {
|
||||||
|
forwardedFrames.Add(1)
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runIdleWatchdog(
|
||||||
|
ctx context.Context,
|
||||||
|
nowFn func() time.Time,
|
||||||
|
idleTimeout time.Duration,
|
||||||
|
lastActivity *atomic.Int64,
|
||||||
|
onTrace func(event RelayTraceEvent),
|
||||||
|
exitCh chan<- relayExitSignal,
|
||||||
|
) {
|
||||||
|
if idleTimeout <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
checkInterval := minDuration(idleTimeout/4, 5*time.Second)
|
||||||
|
if checkInterval < time.Second {
|
||||||
|
checkInterval = time.Second
|
||||||
|
}
|
||||||
|
ticker := time.NewTicker(checkInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
last := time.Unix(0, lastActivity.Load())
|
||||||
|
if nowFn().Sub(last) < idleTimeout {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "idle_timeout_triggered",
|
||||||
|
Direction: "watchdog",
|
||||||
|
Error: context.DeadlineExceeded.Error(),
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
|
||||||
|
if onTrace == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
onTrace(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayMessageTypeString(msgType coderws.MessageType) string {
|
||||||
|
switch msgType {
|
||||||
|
case coderws.MessageText:
|
||||||
|
return "text"
|
||||||
|
case coderws.MessageBinary:
|
||||||
|
return "binary"
|
||||||
|
default:
|
||||||
|
return "unknown(" + strconv.Itoa(int(msgType)) + ")"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayDirectionFromStage(stage string) string {
|
||||||
|
switch stage {
|
||||||
|
case "read_client", "write_upstream":
|
||||||
|
return "client_to_upstream"
|
||||||
|
case "read_upstream", "write_client", "drain_terminal":
|
||||||
|
return "upstream_to_client"
|
||||||
|
case "idle_timeout":
|
||||||
|
return "watchdog"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayErrorString(err error) string {
|
||||||
|
if err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func observeUpstreamMessage(
|
||||||
|
state *relayState,
|
||||||
|
message []byte,
|
||||||
|
startAt time.Time,
|
||||||
|
nowFn func() time.Time,
|
||||||
|
onUsageParseFailure func(eventType string, usageRaw string),
|
||||||
|
) observedUpstreamEvent {
|
||||||
|
if state == nil || len(message) == 0 {
|
||||||
|
return observedUpstreamEvent{}
|
||||||
|
}
|
||||||
|
values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
|
||||||
|
eventType := strings.TrimSpace(values[0].String())
|
||||||
|
if eventType == "" {
|
||||||
|
return observedUpstreamEvent{}
|
||||||
|
}
|
||||||
|
responseID := strings.TrimSpace(values[1].String())
|
||||||
|
if responseID == "" {
|
||||||
|
responseID = strings.TrimSpace(values[2].String())
|
||||||
|
}
|
||||||
|
// 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。
|
||||||
|
if responseID == "" && isTerminalEvent(eventType) {
|
||||||
|
responseID = strings.TrimSpace(values[3].String())
|
||||||
|
}
|
||||||
|
now := nowFn()
|
||||||
|
|
||||||
|
if state.firstTokenMs == nil && isTokenEvent(eventType) {
|
||||||
|
ms := int(now.Sub(startAt).Milliseconds())
|
||||||
|
if ms >= 0 {
|
||||||
|
state.firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
|
||||||
|
observed := observedUpstreamEvent{
|
||||||
|
eventType: eventType,
|
||||||
|
responseID: responseID,
|
||||||
|
usage: parsedUsage,
|
||||||
|
}
|
||||||
|
if responseID != "" {
|
||||||
|
turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
|
||||||
|
if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
|
||||||
|
ms := int(now.Sub(turnTiming.startAt).Milliseconds())
|
||||||
|
if ms >= 0 {
|
||||||
|
turnTiming.firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isTerminalEvent(eventType) {
|
||||||
|
return observed
|
||||||
|
}
|
||||||
|
observed.terminal = true
|
||||||
|
state.terminalEventType = eventType
|
||||||
|
if responseID != "" {
|
||||||
|
state.lastResponseID = responseID
|
||||||
|
if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
|
||||||
|
duration := now.Sub(turnTiming.startAt)
|
||||||
|
if duration < 0 {
|
||||||
|
duration = 0
|
||||||
|
}
|
||||||
|
observed.duration = duration
|
||||||
|
observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return observed
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitTurnComplete(
|
||||||
|
onTurnComplete func(turn RelayTurnResult),
|
||||||
|
state *relayState,
|
||||||
|
observed observedUpstreamEvent,
|
||||||
|
) {
|
||||||
|
if onTurnComplete == nil || !observed.terminal {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
responseID := strings.TrimSpace(observed.responseID)
|
||||||
|
if responseID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requestModel := ""
|
||||||
|
if state != nil {
|
||||||
|
requestModel = state.requestModel
|
||||||
|
}
|
||||||
|
onTurnComplete(RelayTurnResult{
|
||||||
|
RequestModel: requestModel,
|
||||||
|
Usage: observed.usage,
|
||||||
|
RequestID: responseID,
|
||||||
|
TerminalEventType: observed.eventType,
|
||||||
|
Duration: observed.duration,
|
||||||
|
FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
|
||||||
|
if state == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if state.turnTimingByID == nil {
|
||||||
|
state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
|
||||||
|
}
|
||||||
|
timing, ok := state.turnTimingByID[responseID]
|
||||||
|
if !ok || timing == nil || timing.startAt.IsZero() {
|
||||||
|
timing = &relayTurnTiming{startAt: now}
|
||||||
|
state.turnTimingByID[responseID] = timing
|
||||||
|
return timing
|
||||||
|
}
|
||||||
|
return timing
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
|
||||||
|
if state == nil || state.turnTimingByID == nil {
|
||||||
|
return relayTurnTiming{}, false
|
||||||
|
}
|
||||||
|
timing, ok := state.turnTimingByID[responseID]
|
||||||
|
if !ok || timing == nil {
|
||||||
|
return relayTurnTiming{}, false
|
||||||
|
}
|
||||||
|
delete(state.turnTimingByID, responseID)
|
||||||
|
return *timing, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSRelayCloneIntPtr(v *int) *int {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := *v
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseUsageAndAccumulate(
|
||||||
|
state *relayState,
|
||||||
|
message []byte,
|
||||||
|
eventType string,
|
||||||
|
onParseFailure func(eventType string, usageRaw string),
|
||||||
|
) Usage {
|
||||||
|
if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
|
||||||
|
return Usage{}
|
||||||
|
}
|
||||||
|
usageResult := gjson.GetBytes(message, "response.usage")
|
||||||
|
if !usageResult.Exists() {
|
||||||
|
return Usage{}
|
||||||
|
}
|
||||||
|
usageRaw := strings.TrimSpace(usageResult.Raw)
|
||||||
|
if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
|
||||||
|
recordUsageParseFailure()
|
||||||
|
if onParseFailure != nil {
|
||||||
|
onParseFailure(eventType, usageRaw)
|
||||||
|
}
|
||||||
|
return Usage{}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
||||||
|
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
||||||
|
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
||||||
|
|
||||||
|
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
||||||
|
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
||||||
|
cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
|
||||||
|
if !inputOK || !outputOK || !cachedOK {
|
||||||
|
recordUsageParseFailure()
|
||||||
|
if onParseFailure != nil {
|
||||||
|
onParseFailure(eventType, usageRaw)
|
||||||
|
}
|
||||||
|
// 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。
|
||||||
|
return Usage{}
|
||||||
|
}
|
||||||
|
parsedUsage := Usage{
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
CacheReadInputTokens: cachedTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
state.usage.InputTokens += parsedUsage.InputTokens
|
||||||
|
state.usage.OutputTokens += parsedUsage.OutputTokens
|
||||||
|
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
||||||
|
return parsedUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
|
||||||
|
if !value.Exists() {
|
||||||
|
return 0, !required
|
||||||
|
}
|
||||||
|
if value.Type != gjson.Number {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int(value.Int()), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
|
||||||
|
if result == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result.Duration = duration
|
||||||
|
if state == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result.RequestModel = state.requestModel
|
||||||
|
result.Usage = state.usage
|
||||||
|
result.RequestID = state.lastResponseID
|
||||||
|
result.TerminalEventType = state.terminalEventType
|
||||||
|
result.FirstTokenMs = state.firstTokenMs
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDisconnectError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch coderws.CloseStatus(err) {
|
||||||
|
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
message := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||||
|
if message == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(message, "failed to read frame header: eof") ||
|
||||||
|
strings.Contains(message, "unexpected eof") ||
|
||||||
|
strings.Contains(message, "use of closed network connection") ||
|
||||||
|
strings.Contains(message, "connection reset by peer") ||
|
||||||
|
strings.Contains(message, "broken pipe")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTerminalEvent(eventType string) bool {
|
||||||
|
switch eventType {
|
||||||
|
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldParseUsage(eventType string) bool {
|
||||||
|
switch eventType {
|
||||||
|
case "response.completed", "response.done", "response.failed":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTokenEvent(eventType string) bool {
|
||||||
|
if eventType == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch eventType {
|
||||||
|
case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.Contains(eventType, ".delta") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(eventType, "response.output_text") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(eventType, "response.output") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return eventType == "response.completed" || eventType == "response.done"
|
||||||
|
}
|
||||||
|
|
||||||
|
func minDuration(a, b time.Duration) time.Duration {
|
||||||
|
if a <= 0 {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
if b <= 0 {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 200 * time.Millisecond
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case sig := <-exitCh:
|
||||||
|
return sig, true
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return relayExitSignal{}, false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,432 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRunEntry_DelegatesRelay(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
result, relayExit := RunEntry(EntryInput{
|
||||||
|
Ctx: context.Background(),
|
||||||
|
ClientConn: clientConn,
|
||||||
|
UpstreamConn: upstreamConn,
|
||||||
|
FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, "resp_entry", result.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("read client eof", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
runClientToUpstream(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn(nil, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "read_client", sig.stage)
|
||||||
|
require.True(t, sig.graceful)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("write upstream failed", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
runClientToUpstream(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||||
|
}, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "write_upstream", sig.stage)
|
||||||
|
require.False(t, sig.graceful)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("forwarded counter and trace callback", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
forwarded := &atomic.Int64{}
|
||||||
|
traces := make([]RelayTraceEvent, 0, 2)
|
||||||
|
runClientToUpstream(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||||
|
}, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
|
func() {},
|
||||||
|
forwarded,
|
||||||
|
func(event RelayTraceEvent) {
|
||||||
|
traces = append(traces, event)
|
||||||
|
},
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "read_client", sig.stage)
|
||||||
|
require.Equal(t, int64(1), forwarded.Load())
|
||||||
|
require.NotEmpty(t, traces)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("read upstream eof", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
drop := &atomic.Bool{}
|
||||||
|
drop.Store(false)
|
||||||
|
runUpstreamToClient(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn(nil, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
|
time.Now(),
|
||||||
|
time.Now,
|
||||||
|
&relayState{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
drop,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "read_upstream", sig.stage)
|
||||||
|
require.True(t, sig.graceful)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("write client failed", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
drop := &atomic.Bool{}
|
||||||
|
drop.Store(false)
|
||||||
|
runUpstreamToClient(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
|
||||||
|
}, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
|
||||||
|
time.Now(),
|
||||||
|
time.Now,
|
||||||
|
&relayState{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
drop,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "write_client", sig.stage)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("drop downstream and stop on terminal", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
drop := &atomic.Bool{}
|
||||||
|
drop.Store(true)
|
||||||
|
dropped := &atomic.Int64{}
|
||||||
|
runUpstreamToClient(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
|
time.Now(),
|
||||||
|
time.Now,
|
||||||
|
&relayState{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
drop,
|
||||||
|
nil,
|
||||||
|
dropped,
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "drain_terminal", sig.stage)
|
||||||
|
require.True(t, sig.graceful)
|
||||||
|
require.Equal(t, int64(1), dropped.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
lastActivity := &atomic.Int64{}
|
||||||
|
lastActivity.Store(time.Now().UnixNano())
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
|
||||||
|
select {
|
||||||
|
case <-exitCh:
|
||||||
|
t.Fatal("unexpected idle timeout signal")
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHelperFunctionsCoverage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
require.Equal(t, "text", relayMessageTypeString(coderws.MessageText))
|
||||||
|
require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary))
|
||||||
|
require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(")
|
||||||
|
|
||||||
|
require.Equal(t, "", relayErrorString(nil))
|
||||||
|
require.Equal(t, "x", relayErrorString(errors.New("x")))
|
||||||
|
|
||||||
|
require.True(t, isDisconnectError(io.EOF))
|
||||||
|
require.True(t, isDisconnectError(net.ErrClosed))
|
||||||
|
require.True(t, isDisconnectError(context.Canceled))
|
||||||
|
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway}))
|
||||||
|
require.True(t, isDisconnectError(errors.New("broken pipe")))
|
||||||
|
require.False(t, isDisconnectError(errors.New("unrelated")))
|
||||||
|
|
||||||
|
require.True(t, isTokenEvent("response.output_text.delta"))
|
||||||
|
require.True(t, isTokenEvent("response.output_audio.delta"))
|
||||||
|
require.True(t, isTokenEvent("response.completed"))
|
||||||
|
require.False(t, isTokenEvent(""))
|
||||||
|
require.False(t, isTokenEvent("response.created"))
|
||||||
|
|
||||||
|
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second))
|
||||||
|
require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second))
|
||||||
|
require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second))
|
||||||
|
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0))
|
||||||
|
|
||||||
|
ch := make(chan relayExitSignal, 1)
|
||||||
|
ch <- relayExitSignal{stage: "ok"}
|
||||||
|
sig, ok := waitRelayExit(ch, 10*time.Millisecond)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "ok", sig.stage)
|
||||||
|
ch <- relayExitSignal{stage: "ok2"}
|
||||||
|
sig, ok = waitRelayExit(ch, 0)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "ok2", sig.stage)
|
||||||
|
_, ok = waitRelayExit(ch, 10*time.Millisecond)
|
||||||
|
require.False(t, ok)
|
||||||
|
|
||||||
|
n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, 3, n)
|
||||||
|
_, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true)
|
||||||
|
require.False(t, ok)
|
||||||
|
n, ok = parseUsageIntField(gjson.Result{}, false)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, 0, n)
|
||||||
|
_, ok = parseUsageIntField(gjson.Result{}, true)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUsageAndEnrichCoverage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
state := &relayState{}
|
||||||
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil)
|
||||||
|
require.Equal(t, 0, state.usage.InputTokens)
|
||||||
|
|
||||||
|
parseUsageAndAccumulate(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`),
|
||||||
|
"response.completed",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage")
|
||||||
|
require.Equal(t, 0, state.usage.OutputTokens)
|
||||||
|
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
parseUsageAndAccumulate(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`),
|
||||||
|
"response.completed",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage")
|
||||||
|
require.Equal(t, 0, state.usage.OutputTokens)
|
||||||
|
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
|
||||||
|
require.Equal(t, 2, state.usage.InputTokens)
|
||||||
|
require.Equal(t, 1, state.usage.OutputTokens)
|
||||||
|
require.Equal(t, 1, state.usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
result := &RelayResult{}
|
||||||
|
enrichResult(result, state, 5*time.Millisecond)
|
||||||
|
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 5*time.Millisecond, result.Duration)
|
||||||
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
|
||||||
|
require.Equal(t, 2, state.usage.InputTokens)
|
||||||
|
enrichResult(nil, state, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmitTurnCompleteCoverage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 非 terminal 事件不应触发。
|
||||||
|
called := 0
|
||||||
|
emitTurnComplete(func(turn RelayTurnResult) {
|
||||||
|
called++
|
||||||
|
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
|
||||||
|
terminal: false,
|
||||||
|
eventType: "response.output_text.delta",
|
||||||
|
responseID: "resp_ignored",
|
||||||
|
usage: Usage{InputTokens: 1},
|
||||||
|
})
|
||||||
|
require.Equal(t, 0, called)
|
||||||
|
|
||||||
|
// 缺少 response_id 时不应触发。
|
||||||
|
emitTurnComplete(func(turn RelayTurnResult) {
|
||||||
|
called++
|
||||||
|
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
|
||||||
|
terminal: true,
|
||||||
|
eventType: "response.completed",
|
||||||
|
})
|
||||||
|
require.Equal(t, 0, called)
|
||||||
|
|
||||||
|
// terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。
|
||||||
|
var got RelayTurnResult
|
||||||
|
emitTurnComplete(func(turn RelayTurnResult) {
|
||||||
|
called++
|
||||||
|
got = turn
|
||||||
|
}, nil, observedUpstreamEvent{
|
||||||
|
terminal: true,
|
||||||
|
eventType: "response.completed",
|
||||||
|
responseID: "resp_emit",
|
||||||
|
usage: Usage{InputTokens: 2, OutputTokens: 3},
|
||||||
|
})
|
||||||
|
require.Equal(t, 1, called)
|
||||||
|
require.Equal(t, "resp_emit", got.RequestID)
|
||||||
|
require.Equal(t, "response.completed", got.TerminalEventType)
|
||||||
|
require.Equal(t, 2, got.Usage.InputTokens)
|
||||||
|
require.Equal(t, 3, got.Usage.OutputTokens)
|
||||||
|
require.Equal(t, "", got.RequestModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure}))
|
||||||
|
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd}))
|
||||||
|
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure}))
|
||||||
|
require.True(t, isDisconnectError(errors.New("connection reset by peer")))
|
||||||
|
require.False(t, isDisconnectError(errors.New(" ")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTokenEventCoverageBranches(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
require.False(t, isTokenEvent("response.in_progress"))
|
||||||
|
require.False(t, isTokenEvent("response.output_item.added"))
|
||||||
|
require.True(t, isTokenEvent("response.output_audio.delta"))
|
||||||
|
require.True(t, isTokenEvent("response.output"))
|
||||||
|
require.True(t, isTokenEvent("response.done"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Unix(100, 0)
|
||||||
|
// nil state
|
||||||
|
require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now))
|
||||||
|
_, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
|
||||||
|
require.False(t, ok)
|
||||||
|
|
||||||
|
state := &relayState{}
|
||||||
|
timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now)
|
||||||
|
require.NotNil(t, timing)
|
||||||
|
require.Equal(t, now, timing.startAt)
|
||||||
|
|
||||||
|
// 再次获取返回同一条 timing
|
||||||
|
timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second))
|
||||||
|
require.NotNil(t, timing2)
|
||||||
|
require.Equal(t, now, timing2.startAt)
|
||||||
|
|
||||||
|
// 删除存在键
|
||||||
|
deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, now, deleted.startAt)
|
||||||
|
|
||||||
|
// 删除不存在键
|
||||||
|
_, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a")
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
state := &relayState{requestModel: "gpt-5"}
|
||||||
|
startAt := time.Unix(0, 0)
|
||||||
|
now := startAt
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
now = now.Add(5 * time.Millisecond)
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
|
||||||
|
// 非 terminal:仅有顶层 id,不应把 event id 当成 response_id。
|
||||||
|
observed := observeUpstreamMessage(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`),
|
||||||
|
startAt,
|
||||||
|
nowFn,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.False(t, observed.terminal)
|
||||||
|
require.Equal(t, "", observed.responseID)
|
||||||
|
|
||||||
|
// terminal:允许兜底用顶层 id(用于兼容少数字段变体)。
|
||||||
|
observed = observeUpstreamMessage(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
startAt,
|
||||||
|
nowFn,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.True(t, observed.terminal)
|
||||||
|
require.Equal(t, "resp_fallback", observed.responseID)
|
||||||
|
}
|
||||||
752
backend/internal/service/openai_ws_v2/passthrough_relay_test.go
Normal file
752
backend/internal/service/openai_ws_v2/passthrough_relay_test.go
Normal file
@@ -0,0 +1,752 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type passthroughTestFrame struct {
|
||||||
|
msgType coderws.MessageType
|
||||||
|
payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type passthroughTestFrameConn struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
writes []passthroughTestFrame
|
||||||
|
readCh chan passthroughTestFrame
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
type delayedReadFrameConn struct {
|
||||||
|
base FrameConn
|
||||||
|
firstDelay time.Duration
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
type closeSpyFrameConn struct {
|
||||||
|
closeCalls atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn {
|
||||||
|
c := &passthroughTestFrameConn{
|
||||||
|
readCh: make(chan passthroughTestFrame, len(frames)+1),
|
||||||
|
}
|
||||||
|
for _, frame := range frames {
|
||||||
|
copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)}
|
||||||
|
c.readCh <- copied
|
||||||
|
}
|
||||||
|
if autoClose {
|
||||||
|
close(c.readCh)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return coderws.MessageText, nil, ctx.Err()
|
||||||
|
case frame, ok := <-c.readCh:
|
||||||
|
if !ok {
|
||||||
|
return coderws.MessageText, nil, io.EOF
|
||||||
|
}
|
||||||
|
return frame.msgType, append([]byte(nil), frame.payload...), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *passthroughTestFrameConn) Close() error {
|
||||||
|
c.once.Do(func() {
|
||||||
|
defer func() { _ = recover() }()
|
||||||
|
close(c.readCh)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
out := make([]passthroughTestFrame, len(c.writes))
|
||||||
|
copy(out, c.writes)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if c == nil || c.base == nil {
|
||||||
|
return coderws.MessageText, nil, io.EOF
|
||||||
|
}
|
||||||
|
c.once.Do(func() {
|
||||||
|
if c.firstDelay > 0 {
|
||||||
|
timer := time.NewTimer(c.firstDelay)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return c.base.ReadFrame(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||||
|
if c == nil || c.base == nil {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return c.base.WriteFrame(ctx, msgType, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *delayedReadFrameConn) Close() error {
|
||||||
|
if c == nil || c.base == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.base.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
<-ctx.Done()
|
||||||
|
return coderws.MessageText, nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeSpyFrameConn) Close() error {
|
||||||
|
if c != nil {
|
||||||
|
c.closeCalls.Add(1)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeSpyFrameConn) CloseCalls() int32 {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return c.closeCalls.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_BasicRelayAndUsage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, "gpt-5.3-codex", result.RequestModel)
|
||||||
|
require.Equal(t, "resp_123", result.RequestID)
|
||||||
|
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||||
|
require.Equal(t, 7, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 2, result.Usage.CacheReadInputTokens)
|
||||||
|
require.NotNil(t, result.FirstTokenMs)
|
||||||
|
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
|
||||||
|
require.Equal(t, int64(1), result.UpstreamToClientFrames)
|
||||||
|
require.Equal(t, int64(0), result.DroppedDownstreamFrames)
|
||||||
|
|
||||||
|
upstreamWrites := upstreamConn.Writes()
|
||||||
|
require.Len(t, upstreamWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
|
||||||
|
require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload))
|
||||||
|
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
|
||||||
|
require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
|
||||||
|
upstreamWrites := upstreamConn.Writes()
|
||||||
|
require.Len(t, upstreamWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
|
||||||
|
require.Equal(t, firstPayload, upstreamWrites[0].payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_UpstreamDisconnect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 上游立即关闭(EOF),客户端不发送额外帧
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
// 上游 EOF 属于 disconnect,标记为 graceful
|
||||||
|
require.Nil(t, relayExit, "上游 EOF 应被视为 graceful disconnect")
|
||||||
|
require.Equal(t, "gpt-4o", result.RequestModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_ClientDisconnect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 客户端立即关闭(EOF),上游阻塞读取直到 context 取消
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态")
|
||||||
|
require.Equal(t, "client_disconnected", relayExit.Stage)
|
||||||
|
require.Equal(t, "gpt-4o", result.RequestModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
upstreamBase := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
upstreamConn := &delayedReadFrameConn{
|
||||||
|
base: upstreamBase,
|
||||||
|
firstDelay: 80 * time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
UpstreamDrainTimeout: 400 * time.Millisecond,
|
||||||
|
})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "client_disconnected", relayExit.Stage)
|
||||||
|
require.Equal(t, "resp_drain", result.RequestID)
|
||||||
|
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||||
|
require.Equal(t, 6, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 1, result.Usage.CacheReadInputTokens)
|
||||||
|
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
|
||||||
|
require.Equal(t, int64(0), result.UpstreamToClientFrames)
|
||||||
|
require.Equal(t, int64(1), result.DroppedDownstreamFrames)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_IdleTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 客户端和上游都不发送帧,idle timeout 应触发
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 使用快进时间来加速 idle timeout
|
||||||
|
now := time.Now()
|
||||||
|
callCount := 0
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
callCount++
|
||||||
|
// 前几次调用返回正常时间(初始化阶段),之后快进
|
||||||
|
if callCount <= 5 {
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
return now.Add(time.Hour) // 快进到超时
|
||||||
|
}
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
IdleTimeout: 2 * time.Second,
|
||||||
|
Now: nowFn,
|
||||||
|
})
|
||||||
|
require.NotNil(t, relayExit, "应因 idle timeout 退出")
|
||||||
|
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||||||
|
require.Equal(t, "gpt-4o", result.RequestModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := &closeSpyFrameConn{}
|
||||||
|
upstreamConn := &closeSpyFrameConn{}
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
callCount := 0
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
callCount++
|
||||||
|
if callCount <= 5 {
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
return now.Add(time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
IdleTimeout: 2 * time.Second,
|
||||||
|
Now: nowFn,
|
||||||
|
})
|
||||||
|
require.NotNil(t, relayExit, "应因 idle timeout 退出")
|
||||||
|
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||||||
|
require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code")
|
||||||
|
require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_NilConnections(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("nil client conn", func(t *testing.T) {
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
_, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "relay_init", relayExit.Stage)
|
||||||
|
require.Contains(t, relayExit.Err.Error(), "nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil upstream conn", func(t *testing.T) {
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
_, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "relay_init", relayExit.Stage)
|
||||||
|
require.Contains(t, relayExit.Err.Error(), "nil")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_MultipleUpstreamMessages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 上游发送多个事件(delta + completed),验证多帧中继和 usage 聚合
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, "resp_multi", result.RequestID)
|
||||||
|
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||||
|
require.Equal(t, 10, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||||
|
require.NotNil(t, result.FirstTokenMs)
|
||||||
|
|
||||||
|
// 验证所有 3 个上游帧都转发给了客户端
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
turns := make([]RelayTurnResult, 0, 2)
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
OnTurnComplete: func(turn RelayTurnResult) {
|
||||||
|
turns = append(turns, turn)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Len(t, turns, 2)
|
||||||
|
require.Equal(t, "resp_turn_1", turns[0].RequestID)
|
||||||
|
require.Equal(t, "response.completed", turns[0].TerminalEventType)
|
||||||
|
require.Equal(t, 2, turns[0].Usage.InputTokens)
|
||||||
|
require.Equal(t, 1, turns[0].Usage.OutputTokens)
|
||||||
|
require.Equal(t, "resp_turn_2", turns[1].RequestID)
|
||||||
|
require.Equal(t, "response.failed", turns[1].TerminalEventType)
|
||||||
|
require.Equal(t, 3, turns[1].Usage.InputTokens)
|
||||||
|
require.Equal(t, 4, turns[1].Usage.OutputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
base := time.Unix(0, 0)
|
||||||
|
var nowTick atomic.Int64
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
step := nowTick.Add(1)
|
||||||
|
return base.Add(time.Duration(step) * 5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
var turn RelayTurnResult
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
Now: nowFn,
|
||||||
|
OnTurnComplete: func(current RelayTurnResult) {
|
||||||
|
turn = current
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, "resp_metric", turn.RequestID)
|
||||||
|
require.Equal(t, "response.completed", turn.TerminalEventType)
|
||||||
|
require.NotNil(t, turn.FirstTokenMs)
|
||||||
|
require.GreaterOrEqual(t, *turn.FirstTokenMs, 0)
|
||||||
|
require.Greater(t, turn.Duration.Milliseconds(), int64(0))
|
||||||
|
require.NotNil(t, result.FirstTokenMs)
|
||||||
|
require.Greater(t, result.Duration.Milliseconds(), int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_BinaryFramePassthrough(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 验证 binary frame 被透传但不进行 usage 解析
|
||||||
|
binaryPayload := []byte{0x00, 0x01, 0x02, 0x03}
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageBinary,
|
||||||
|
payload: binaryPayload,
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
// binary frame 不解析 usage
|
||||||
|
require.Equal(t, 0, result.Usage.InputTokens)
|
||||||
|
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
|
||||||
|
require.Equal(t, binaryPayload, clientWrites[0].payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageBinary,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, 0, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, "", result.RequestID)
|
||||||
|
require.Equal(t, "", result.TerminalEventType)
|
||||||
|
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: errorEvent,
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
|
||||||
|
require.Equal(t, errorEvent, clientWrites[0].payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_PreservesFirstMessageType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
FirstMessageType: coderws.MessageBinary,
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
|
||||||
|
upstreamWrites := upstreamConn.Writes()
|
||||||
|
require.Len(t, upstreamWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType)
|
||||||
|
require.Equal(t, firstPayload, upstreamWrites[0].payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) {
|
||||||
|
baseline := SnapshotMetrics().UsageParseFailureTotal
|
||||||
|
|
||||||
|
// 上游发送无效 JSON(非 usage 格式),不应影响透传
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
// usage 解析失败,值为 0 但不影响透传
|
||||||
|
require.Equal(t, 0, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||||
|
|
||||||
|
// 帧仍然被转发
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, baseline+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 上游连接立即关闭,首包写入失败
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
_ = upstreamConn.Close()
|
||||||
|
|
||||||
|
// 覆盖 WriteFrame 使其返回错误
|
||||||
|
errConn := &errorOnWriteFrameConn{}
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "write_upstream", relayExit.Stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_ContextCanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
|
||||||
|
// 立即取消 context
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
// context 取消导致写首包失败
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stages := make([]string, 0, 8)
|
||||||
|
var stagesMu sync.Mutex
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
OnTrace: func(event RelayTraceEvent) {
|
||||||
|
stagesMu.Lock()
|
||||||
|
stages = append(stages, event.Stage)
|
||||||
|
stagesMu.Unlock()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
stagesMu.Lock()
|
||||||
|
capturedStages := append([]string(nil), stages...)
|
||||||
|
stagesMu.Unlock()
|
||||||
|
require.Contains(t, capturedStages, "relay_start")
|
||||||
|
require.Contains(t, capturedStages, "write_first_message_ok")
|
||||||
|
require.Contains(t, capturedStages, "first_exit")
|
||||||
|
require.Contains(t, capturedStages, "relay_complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_TraceEvents_IdleTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
callCount := 0
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
callCount++
|
||||||
|
if callCount <= 5 {
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
return now.Add(time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
stages := make([]string, 0, 8)
|
||||||
|
var stagesMu sync.Mutex
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
IdleTimeout: 2 * time.Second,
|
||||||
|
Now: nowFn,
|
||||||
|
OnTrace: func(event RelayTraceEvent) {
|
||||||
|
stagesMu.Lock()
|
||||||
|
stages = append(stages, event.Stage)
|
||||||
|
stagesMu.Unlock()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||||||
|
stagesMu.Lock()
|
||||||
|
capturedStages := append([]string(nil), stages...)
|
||||||
|
stagesMu.Unlock()
|
||||||
|
require.Contains(t, capturedStages, "idle_timeout_triggered")
|
||||||
|
require.Contains(t, capturedStages, "relay_exit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorOnWriteFrameConn 是一个写入总是失败的 FrameConn 实现,用于测试首包写入失败。
|
||||||
|
type errorOnWriteFrameConn struct{}
|
||||||
|
|
||||||
|
func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
<-ctx.Done()
|
||||||
|
return coderws.MessageText, nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error {
|
||||||
|
return errors.New("write failed: connection refused")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *errorOnWriteFrameConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
367
backend/internal/service/openai_ws_v2_passthrough_adapter.go
Normal file
367
backend/internal/service/openai_ws_v2_passthrough_adapter.go
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAIWSClientFrameConn struct {
|
||||||
|
conn *coderws.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||||||
|
|
||||||
|
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||||||
|
|
||||||
|
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return c.conn.Read(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return errOpenAIWSConnClosed
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return c.conn.Write(ctx, msgType, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openAIWSClientFrameConn) Close() error {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_ = c.conn.Close(coderws.StatusNormalClosure, "")
|
||||||
|
_ = c.conn.CloseNow()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||||
|
ctx context.Context,
|
||||||
|
c *gin.Context,
|
||||||
|
clientConn *coderws.Conn,
|
||||||
|
account *Account,
|
||||||
|
token string,
|
||||||
|
firstClientMessage []byte,
|
||||||
|
hooks *OpenAIWSIngressHooks,
|
||||||
|
wsDecision OpenAIWSProtocolDecision,
|
||||||
|
) error {
|
||||||
|
if s == nil {
|
||||||
|
return errors.New("service is nil")
|
||||||
|
}
|
||||||
|
if clientConn == nil {
|
||||||
|
return errors.New("client websocket is nil")
|
||||||
|
}
|
||||||
|
if account == nil {
|
||||||
|
return errors.New("account is nil")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(token) == "" {
|
||||||
|
return errors.New("token is empty")
|
||||||
|
}
|
||||||
|
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
|
||||||
|
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
|
||||||
|
account.ID,
|
||||||
|
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
|
||||||
|
openaiwsv2RelayMessageTypeName(coderws.MessageText),
|
||||||
|
len(firstClientMessage),
|
||||||
|
)
|
||||||
|
|
||||||
|
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build ws url: %w", err)
|
||||||
|
}
|
||||||
|
wsHost := "-"
|
||||||
|
wsPath := "-"
|
||||||
|
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
|
||||||
|
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
|
||||||
|
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
|
||||||
|
}
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
|
||||||
|
account.ID,
|
||||||
|
wsHost,
|
||||||
|
wsPath,
|
||||||
|
account.ProxyID != nil && account.Proxy != nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
isCodexCLI := false
|
||||||
|
if c != nil {
|
||||||
|
isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||||
|
}
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||||
|
isCodexCLI = true
|
||||||
|
}
|
||||||
|
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := s.getOpenAIWSPassthroughDialer()
|
||||||
|
if dialer == nil {
|
||||||
|
return errors.New("openai ws passthrough dialer is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
|
||||||
|
defer cancelDial()
|
||||||
|
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_dial_failed account_id=%d status_code=%d err=%s",
|
||||||
|
account.ID,
|
||||||
|
statusCode,
|
||||||
|
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
||||||
|
)
|
||||||
|
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = upstreamConn.Close()
|
||||||
|
}()
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
|
||||||
|
account.ID,
|
||||||
|
statusCode,
|
||||||
|
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("openai ws passthrough upstream connection does not support frame relay")
|
||||||
|
}
|
||||||
|
|
||||||
|
completedTurns := atomic.Int32{}
|
||||||
|
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
||||||
|
Ctx: ctx,
|
||||||
|
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
|
||||||
|
UpstreamConn: upstreamFrameConn,
|
||||||
|
FirstClientMessage: firstClientMessage,
|
||||||
|
Options: openaiwsv2.RelayOptions{
|
||||||
|
WriteTimeout: s.openAIWSWriteTimeout(),
|
||||||
|
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
|
||||||
|
FirstMessageType: coderws.MessageText,
|
||||||
|
OnUsageParseFailure: func(eventType string, usageRaw string) {
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"usage_parse_failed event_type=%s usage_raw=%s",
|
||||||
|
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
|
||||||
|
turnNo := int(completedTurns.Add(1))
|
||||||
|
turnResult := &OpenAIForwardResult{
|
||||||
|
RequestID: turn.RequestID,
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: turn.Usage.InputTokens,
|
||||||
|
OutputTokens: turn.Usage.OutputTokens,
|
||||||
|
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
||||||
|
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||||
|
},
|
||||||
|
Model: turn.RequestModel,
|
||||||
|
Stream: true,
|
||||||
|
OpenAIWSMode: true,
|
||||||
|
Duration: turn.Duration,
|
||||||
|
FirstTokenMs: turn.FirstTokenMs,
|
||||||
|
}
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
|
||||||
|
account.ID,
|
||||||
|
turnNo,
|
||||||
|
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
|
||||||
|
turnResult.Duration.Milliseconds(),
|
||||||
|
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
|
||||||
|
turnResult.Usage.InputTokens,
|
||||||
|
turnResult.Usage.OutputTokens,
|
||||||
|
turnResult.Usage.CacheReadInputTokens,
|
||||||
|
)
|
||||||
|
if hooks != nil && hooks.AfterTurn != nil {
|
||||||
|
hooks.AfterTurn(turnNo, turnResult, nil)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
|
||||||
|
account.ID,
|
||||||
|
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
|
||||||
|
event.PayloadBytes,
|
||||||
|
event.Graceful,
|
||||||
|
event.WroteDownstream,
|
||||||
|
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
result := &OpenAIForwardResult{
|
||||||
|
RequestID: relayResult.RequestID,
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: relayResult.Usage.InputTokens,
|
||||||
|
OutputTokens: relayResult.Usage.OutputTokens,
|
||||||
|
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
||||||
|
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||||
|
},
|
||||||
|
Model: relayResult.RequestModel,
|
||||||
|
Stream: true,
|
||||||
|
OpenAIWSMode: true,
|
||||||
|
Duration: relayResult.Duration,
|
||||||
|
FirstTokenMs: relayResult.FirstTokenMs,
|
||||||
|
}
|
||||||
|
|
||||||
|
turnCount := int(completedTurns.Load())
|
||||||
|
if relayExit == nil {
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
||||||
|
account.ID,
|
||||||
|
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
|
||||||
|
result.Duration.Milliseconds(),
|
||||||
|
relayResult.ClientToUpstreamFrames,
|
||||||
|
relayResult.UpstreamToClientFrames,
|
||||||
|
relayResult.DroppedDownstreamFrames,
|
||||||
|
turnCount,
|
||||||
|
)
|
||||||
|
// 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。
|
||||||
|
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
|
||||||
|
hooks.AfterTurn(1, result, nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
||||||
|
account.ID,
|
||||||
|
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
|
||||||
|
relayExit.WroteDownstream,
|
||||||
|
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
|
||||||
|
result.Duration.Milliseconds(),
|
||||||
|
relayResult.ClientToUpstreamFrames,
|
||||||
|
relayResult.UpstreamToClientFrames,
|
||||||
|
relayResult.DroppedDownstreamFrames,
|
||||||
|
turnCount,
|
||||||
|
)
|
||||||
|
|
||||||
|
relayErr := relayExit.Err
|
||||||
|
if relayExit.Stage == "idle_timeout" {
|
||||||
|
relayErr = NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"client websocket idle timeout",
|
||||||
|
relayErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
turnErr := wrapOpenAIWSIngressTurnError(
|
||||||
|
relayExit.Stage,
|
||||||
|
relayErr,
|
||||||
|
relayExit.WroteDownstream,
|
||||||
|
)
|
||||||
|
if hooks != nil && hooks.AfterTurn != nil {
|
||||||
|
hooks.AfterTurn(turnCount+1, nil, turnErr)
|
||||||
|
}
|
||||||
|
return turnErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
|
||||||
|
err error,
|
||||||
|
statusCode int,
|
||||||
|
handshakeHeaders http.Header,
|
||||||
|
) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
wrappedErr := err
|
||||||
|
var dialErr *openAIWSDialError
|
||||||
|
if !errors.As(err, &dialErr) {
|
||||||
|
wrappedErr = &openAIWSDialError{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusTryAgainLater,
|
||||||
|
"upstream websocket connect timeout",
|
||||||
|
wrappedErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusTooManyRequests {
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusTryAgainLater,
|
||||||
|
"upstream websocket is busy, please retry later",
|
||||||
|
wrappedErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"upstream websocket authentication failed",
|
||||||
|
wrappedErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"upstream websocket handshake rejected",
|
||||||
|
wrappedErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
|
||||||
|
switch msgType {
|
||||||
|
case coderws.MessageText:
|
||||||
|
return "text"
|
||||||
|
case coderws.MessageBinary:
|
||||||
|
return "binary"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown(%d)", msgType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayErrorText(err error) string {
|
||||||
|
if err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
|
||||||
|
if firstTokenMs == nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return *firstTokenMs
|
||||||
|
}
|
||||||
|
|
||||||
|
func logOpenAIWSV2Passthrough(format string, args ...any) {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.openai_ws_v2",
|
||||||
|
"[OpenAI WS v2 passthrough] %s "+format,
|
||||||
|
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -31,6 +31,10 @@ func (s *OpsService) GetDashboardOverview(ctx context.Context, filter *OpsDashbo
|
|||||||
filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
|
filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
|
||||||
|
|
||||||
overview, err := s.opsRepo.GetDashboardOverview(ctx, filter)
|
overview, err := s.opsRepo.GetDashboardOverview(ctx, filter)
|
||||||
|
if err != nil && shouldFallbackOpsPreagg(filter, err) {
|
||||||
|
rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
|
||||||
|
overview, err = s.opsRepo.GetDashboardOverview(ctx, rawFilter)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrOpsPreaggregatedNotPopulated) {
|
if errors.Is(err, ErrOpsPreaggregatedNotPopulated) {
|
||||||
return nil, infraerrors.Conflict("OPS_PREAGG_NOT_READY", "Pre-aggregated ops metrics are not populated yet")
|
return nil, infraerrors.Conflict("OPS_PREAGG_NOT_READY", "Pre-aggregated ops metrics are not populated yet")
|
||||||
|
|||||||
@@ -22,7 +22,14 @@ func (s *OpsService) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilt
|
|||||||
if filter.StartTime.After(filter.EndTime) {
|
if filter.StartTime.After(filter.EndTime) {
|
||||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||||
}
|
}
|
||||||
return s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds)
|
filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
|
||||||
|
|
||||||
|
result, err := s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds)
|
||||||
|
if err != nil && shouldFallbackOpsPreagg(filter, err) {
|
||||||
|
rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
|
||||||
|
return s.opsRepo.GetErrorTrend(ctx, rawFilter, bucketSeconds)
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) {
|
func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) {
|
||||||
@@ -41,5 +48,12 @@ func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashbo
|
|||||||
if filter.StartTime.After(filter.EndTime) {
|
if filter.StartTime.After(filter.EndTime) {
|
||||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||||
}
|
}
|
||||||
return s.opsRepo.GetErrorDistribution(ctx, filter)
|
filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
|
||||||
|
|
||||||
|
result, err := s.opsRepo.GetErrorDistribution(ctx, filter)
|
||||||
|
if err != nil && shouldFallbackOpsPreagg(filter, err) {
|
||||||
|
rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
|
||||||
|
return s.opsRepo.GetErrorDistribution(ctx, rawFilter)
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,5 +22,12 @@ func (s *OpsService) GetLatencyHistogram(ctx context.Context, filter *OpsDashboa
|
|||||||
if filter.StartTime.After(filter.EndTime) {
|
if filter.StartTime.After(filter.EndTime) {
|
||||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||||
}
|
}
|
||||||
return s.opsRepo.GetLatencyHistogram(ctx, filter)
|
filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
|
||||||
|
|
||||||
|
result, err := s.opsRepo.GetLatencyHistogram(ctx, filter)
|
||||||
|
if err != nil && shouldFallbackOpsPreagg(filter, err) {
|
||||||
|
rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
|
||||||
|
return s.opsRepo.GetLatencyHistogram(ctx, rawFilter)
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,3 +38,18 @@ func (m OpsQueryMode) IsValid() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldFallbackOpsPreagg(filter *OpsDashboardFilter, err error) bool {
|
||||||
|
return filter != nil &&
|
||||||
|
filter.QueryMode == OpsQueryModeAuto &&
|
||||||
|
errors.Is(err, ErrOpsPreaggregatedNotPopulated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneOpsFilterWithMode(filter *OpsDashboardFilter, mode OpsQueryMode) *OpsDashboardFilter {
|
||||||
|
if filter == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := *filter
|
||||||
|
cloned.QueryMode = mode
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|||||||
66
backend/internal/service/ops_query_mode_test.go
Normal file
66
backend/internal/service/ops_query_mode_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestShouldFallbackOpsPreagg(t *testing.T) {
|
||||||
|
preaggErr := ErrOpsPreaggregatedNotPopulated
|
||||||
|
otherErr := errors.New("some other error")
|
||||||
|
|
||||||
|
autoFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeAuto}
|
||||||
|
rawFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeRaw}
|
||||||
|
preaggFilter := &OpsDashboardFilter{QueryMode: OpsQueryModePreagg}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filter *OpsDashboardFilter
|
||||||
|
err error
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"auto mode + preagg error => fallback", autoFilter, preaggErr, true},
|
||||||
|
{"auto mode + other error => no fallback", autoFilter, otherErr, false},
|
||||||
|
{"auto mode + nil error => no fallback", autoFilter, nil, false},
|
||||||
|
{"raw mode + preagg error => no fallback", rawFilter, preaggErr, false},
|
||||||
|
{"preagg mode + preagg error => no fallback", preaggFilter, preaggErr, false},
|
||||||
|
{"nil filter => no fallback", nil, preaggErr, false},
|
||||||
|
{"wrapped preagg error => fallback", autoFilter, errors.Join(preaggErr, otherErr), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := shouldFallbackOpsPreagg(tc.filter, tc.err)
|
||||||
|
require.Equal(t, tc.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneOpsFilterWithMode(t *testing.T) {
|
||||||
|
t.Run("nil filter returns nil", func(t *testing.T) {
|
||||||
|
require.Nil(t, cloneOpsFilterWithMode(nil, OpsQueryModeRaw))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloned filter has new mode", func(t *testing.T) {
|
||||||
|
groupID := int64(42)
|
||||||
|
original := &OpsDashboardFilter{
|
||||||
|
StartTime: time.Now(),
|
||||||
|
EndTime: time.Now().Add(time.Hour),
|
||||||
|
Platform: "anthropic",
|
||||||
|
GroupID: &groupID,
|
||||||
|
QueryMode: OpsQueryModeAuto,
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned := cloneOpsFilterWithMode(original, OpsQueryModeRaw)
|
||||||
|
require.Equal(t, OpsQueryModeRaw, cloned.QueryMode)
|
||||||
|
require.Equal(t, OpsQueryModeAuto, original.QueryMode, "original should not be modified")
|
||||||
|
require.Equal(t, original.Platform, cloned.Platform)
|
||||||
|
require.Equal(t, original.StartTime, cloned.StartTime)
|
||||||
|
require.Equal(t, original.GroupID, cloned.GroupID)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -22,5 +22,13 @@ func (s *OpsService) GetThroughputTrend(ctx context.Context, filter *OpsDashboar
|
|||||||
if filter.StartTime.After(filter.EndTime) {
|
if filter.StartTime.After(filter.EndTime) {
|
||||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||||
}
|
}
|
||||||
return s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds)
|
|
||||||
|
filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
|
||||||
|
|
||||||
|
result, err := s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds)
|
||||||
|
if err != nil && shouldFallbackOpsPreagg(filter, err) {
|
||||||
|
rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw)
|
||||||
|
return s.opsRepo.GetThroughputTrend(ctx, rawFilter, bucketSeconds)
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -676,7 +676,17 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 没有重置时间,使用默认5分钟
|
// Anthropic 平台:没有限流重置时间的 429 可能是非真实限流(如 Extra usage required),
|
||||||
|
// 不标记账号限流状态,直接透传错误给客户端
|
||||||
|
if account.Platform == PlatformAnthropic {
|
||||||
|
slog.Warn("rate_limit_429_no_reset_time_skipped",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"platform", account.Platform,
|
||||||
|
"reason", "no rate limit reset time in headers, likely not a real rate limit")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 其他平台:没有重置时间,使用默认5分钟
|
||||||
resetAt := time.Now().Add(5 * time.Minute)
|
resetAt := time.Now().Add(5 * time.Minute)
|
||||||
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
|
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||||
@@ -1081,6 +1091,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
|
|||||||
if !account.IsTempUnschedulableEnabled() {
|
if !account.IsTempUnschedulableEnabled() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
// 401 首次命中可临时不可调度(给 token 刷新窗口);
|
||||||
|
// 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。
|
||||||
|
if statusCode == http.StatusUnauthorized {
|
||||||
|
reason := account.TempUnschedulableReason
|
||||||
|
// 缓存可能没有 reason,从 DB 回退读取
|
||||||
|
if reason == "" {
|
||||||
|
if dbAcc, err := s.accountRepo.GetByID(ctx, account.ID); err == nil && dbAcc != nil {
|
||||||
|
reason = dbAcc.TempUnschedulableReason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if wasTempUnschedByStatusCode(reason, statusCode) {
|
||||||
|
slog.Info("401_escalated_to_error", "account_id", account.ID,
|
||||||
|
"reason", "previous temp-unschedulable was also 401")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
rules := account.GetTempUnschedulableRules()
|
rules := account.GetTempUnschedulableRules()
|
||||||
if len(rules) == 0 {
|
if len(rules) == 0 {
|
||||||
return false
|
return false
|
||||||
@@ -1112,6 +1138,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func wasTempUnschedByStatusCode(reason string, statusCode int) bool {
|
||||||
|
if statusCode <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
reason = strings.TrimSpace(reason)
|
||||||
|
if reason == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var state TempUnschedState
|
||||||
|
if err := json.Unmarshal([]byte(reason), &state); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return state.StatusCode == statusCode
|
||||||
|
}
|
||||||
|
|
||||||
func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
|
func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
|
||||||
if bodyLower == "" {
|
if bodyLower == "" {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -0,0 +1,119 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dbFallbackRepoStub extends errorPolicyRepoStub with a configurable DB account
|
||||||
|
// returned by GetByID, simulating cache miss + DB fallback.
|
||||||
|
type dbFallbackRepoStub struct {
|
||||||
|
errorPolicyRepoStub
|
||||||
|
dbAccount *Account // returned by GetByID when non-nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||||
|
if r.dbAccount != nil && r.dbAccount.ID == id {
|
||||||
|
return r.dbAccount, nil
|
||||||
|
}
|
||||||
|
return nil, nil // not found, no error
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
|
||||||
|
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
|
||||||
|
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
|
||||||
|
repo := &dbFallbackRepoStub{
|
||||||
|
dbAccount: &Account{
|
||||||
|
ID: 20,
|
||||||
|
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 20,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
TempUnschedulableReason: "", // cache miss — reason is empty
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(401),
|
||||||
|
"keywords": []any{"unauthorized"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||||
|
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
|
||||||
|
// Scenario: cache account has empty TempUnschedulableReason,
|
||||||
|
// DB also has no previous 401 record → should NOT escalate (first hit → temp unscheduled).
|
||||||
|
repo := &dbFallbackRepoStub{
|
||||||
|
dbAccount: &Account{
|
||||||
|
ID: 21,
|
||||||
|
TempUnschedulableReason: "", // DB also empty
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 21,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
TempUnschedulableReason: "",
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(401),
|
||||||
|
"keywords": []any{"unauthorized"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||||
|
require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with no DB record should temp-unschedule")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckErrorPolicy_401_DBFallback_DBError_FirstHit(t *testing.T) {
|
||||||
|
// Scenario: cache account has empty TempUnschedulableReason,
|
||||||
|
// DB lookup returns nil (not found) → should treat as first hit → temp unscheduled.
|
||||||
|
repo := &dbFallbackRepoStub{
|
||||||
|
dbAccount: nil, // GetByID returns nil, nil
|
||||||
|
}
|
||||||
|
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 22,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
TempUnschedulableReason: "",
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(401),
|
||||||
|
"keywords": []any{"unauthorized"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||||
|
require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with DB not found should temp-unschedule")
|
||||||
|
}
|
||||||
123
backend/internal/service/registration_email_policy.go
Normal file
123
backend/internal/service/registration_email_policy.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var registrationEmailDomainPattern = regexp.MustCompile(
|
||||||
|
`^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+$`,
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegistrationEmailSuffix extracts normalized suffix in "@domain" form.
|
||||||
|
func RegistrationEmailSuffix(email string) string {
|
||||||
|
_, domain, ok := splitEmailForPolicy(email)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "@" + domain
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRegistrationEmailSuffixAllowed checks whether an email is allowed by suffix whitelist.
|
||||||
|
// Empty whitelist means allow all.
|
||||||
|
func IsRegistrationEmailSuffixAllowed(email string, whitelist []string) bool {
|
||||||
|
if len(whitelist) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
suffix := RegistrationEmailSuffix(email)
|
||||||
|
if suffix == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, allowed := range whitelist {
|
||||||
|
if suffix == allowed {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeRegistrationEmailSuffixWhitelist normalizes and validates suffix whitelist items.
|
||||||
|
func NormalizeRegistrationEmailSuffixWhitelist(raw []string) ([]string, error) {
|
||||||
|
return normalizeRegistrationEmailSuffixWhitelist(raw, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseRegistrationEmailSuffixWhitelist parses persisted JSON into normalized suffixes.
|
||||||
|
// Invalid entries are ignored to keep old misconfigurations from breaking runtime reads.
|
||||||
|
func ParseRegistrationEmailSuffixWhitelist(raw string) []string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
var items []string
|
||||||
|
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
normalized, _ := normalizeRegistrationEmailSuffixWhitelist(items, false)
|
||||||
|
if len(normalized) == 0 {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeRegistrationEmailSuffixWhitelist(raw []string, strict bool) ([]string, error) {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]struct{}, len(raw))
|
||||||
|
out := make([]string, 0, len(raw))
|
||||||
|
for _, item := range raw {
|
||||||
|
normalized, err := normalizeRegistrationEmailSuffix(item)
|
||||||
|
if err != nil {
|
||||||
|
if strict {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if normalized == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[normalized]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[normalized] = struct{}{}
|
||||||
|
out = append(out, normalized)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeRegistrationEmailSuffix(raw string) (string, error) {
|
||||||
|
value := strings.ToLower(strings.TrimSpace(raw))
|
||||||
|
if value == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
domain := value
|
||||||
|
if strings.Contains(value, "@") {
|
||||||
|
if !strings.HasPrefix(value, "@") || strings.Count(value, "@") != 1 {
|
||||||
|
return "", fmt.Errorf("invalid email suffix: %q", raw)
|
||||||
|
}
|
||||||
|
domain = strings.TrimPrefix(value, "@")
|
||||||
|
}
|
||||||
|
|
||||||
|
if domain == "" || strings.Contains(domain, "@") || !registrationEmailDomainPattern.MatchString(domain) {
|
||||||
|
return "", fmt.Errorf("invalid email suffix: %q", raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
return "@" + domain, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitEmailForPolicy(raw string) (local string, domain string, ok bool) {
|
||||||
|
email := strings.ToLower(strings.TrimSpace(raw))
|
||||||
|
local, domain, found := strings.Cut(email, "@")
|
||||||
|
if !found || local == "" || domain == "" || strings.Contains(domain, "@") {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
return local, domain, true
|
||||||
|
}
|
||||||
31
backend/internal/service/registration_email_policy_test.go
Normal file
31
backend/internal/service/registration_email_policy_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeRegistrationEmailSuffixWhitelist(t *testing.T) {
|
||||||
|
got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar "})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []string{"@example.com", "@foo.bar"}, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeRegistrationEmailSuffixWhitelist_Invalid(t *testing.T) {
|
||||||
|
_, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"@invalid_domain"})
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRegistrationEmailSuffixWhitelist(t *testing.T) {
|
||||||
|
got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","@invalid_domain"]`)
|
||||||
|
require.Equal(t, []string{"@example.com", "@foo.bar"}, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsRegistrationEmailSuffixAllowed(t *testing.T) {
|
||||||
|
require.True(t, IsRegistrationEmailSuffixAllowed("user@example.com", []string{"@example.com"}))
|
||||||
|
require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.example.com", []string{"@example.com"}))
|
||||||
|
require.True(t, IsRegistrationEmailSuffixAllowed("user@any.com", []string{}))
|
||||||
|
}
|
||||||
51
backend/internal/service/scheduled_test_port.go
Normal file
51
backend/internal/service/scheduled_test_port.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ScheduledTestPlan represents a scheduled test plan domain model.
|
||||||
|
type ScheduledTestPlan struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
AccountID int64 `json:"account_id"`
|
||||||
|
ModelID string `json:"model_id"`
|
||||||
|
CronExpression string `json:"cron_expression"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
MaxResults int `json:"max_results"`
|
||||||
|
LastRunAt *time.Time `json:"last_run_at"`
|
||||||
|
NextRunAt *time.Time `json:"next_run_at"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScheduledTestResult represents a single test execution result.
|
||||||
|
type ScheduledTestResult struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
PlanID int64 `json:"plan_id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
ResponseText string `json:"response_text"`
|
||||||
|
ErrorMessage string `json:"error_message"`
|
||||||
|
LatencyMs int64 `json:"latency_ms"`
|
||||||
|
StartedAt time.Time `json:"started_at"`
|
||||||
|
FinishedAt time.Time `json:"finished_at"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScheduledTestPlanRepository defines the data access interface for test plans.
|
||||||
|
type ScheduledTestPlanRepository interface {
|
||||||
|
Create(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error)
|
||||||
|
GetByID(ctx context.Context, id int64) (*ScheduledTestPlan, error)
|
||||||
|
ListByAccountID(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error)
|
||||||
|
ListDue(ctx context.Context, now time.Time) ([]*ScheduledTestPlan, error)
|
||||||
|
Update(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error)
|
||||||
|
Delete(ctx context.Context, id int64) error
|
||||||
|
UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScheduledTestResultRepository defines the data access interface for test results.
|
||||||
|
type ScheduledTestResultRepository interface {
|
||||||
|
Create(ctx context.Context, result *ScheduledTestResult) (*ScheduledTestResult, error)
|
||||||
|
ListByPlanID(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error)
|
||||||
|
PruneOldResults(ctx context.Context, planID int64, keepCount int) error
|
||||||
|
}
|
||||||
139
backend/internal/service/scheduled_test_runner_service.go
Normal file
139
backend/internal/service/scheduled_test_runner_service.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/robfig/cron/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
const scheduledTestDefaultMaxWorkers = 10
|
||||||
|
|
||||||
|
// ScheduledTestRunnerService periodically scans due test plans and executes them.
|
||||||
|
type ScheduledTestRunnerService struct {
|
||||||
|
planRepo ScheduledTestPlanRepository
|
||||||
|
scheduledSvc *ScheduledTestService
|
||||||
|
accountTestSvc *AccountTestService
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
cron *cron.Cron
|
||||||
|
startOnce sync.Once
|
||||||
|
stopOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScheduledTestRunnerService creates a new runner.
|
||||||
|
func NewScheduledTestRunnerService(
|
||||||
|
planRepo ScheduledTestPlanRepository,
|
||||||
|
scheduledSvc *ScheduledTestService,
|
||||||
|
accountTestSvc *AccountTestService,
|
||||||
|
cfg *config.Config,
|
||||||
|
) *ScheduledTestRunnerService {
|
||||||
|
return &ScheduledTestRunnerService{
|
||||||
|
planRepo: planRepo,
|
||||||
|
scheduledSvc: scheduledSvc,
|
||||||
|
accountTestSvc: accountTestSvc,
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins the cron ticker (every minute).
|
||||||
|
func (s *ScheduledTestRunnerService) Start() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.startOnce.Do(func() {
|
||||||
|
loc := time.Local
|
||||||
|
if s.cfg != nil {
|
||||||
|
if parsed, err := time.LoadLocation(s.cfg.Timezone); err == nil && parsed != nil {
|
||||||
|
loc = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c := cron.New(cron.WithParser(scheduledTestCronParser), cron.WithLocation(loc))
|
||||||
|
_, err := c.AddFunc("* * * * *", func() { s.runScheduled() })
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] not started (invalid schedule): %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.cron = c
|
||||||
|
s.cron.Start()
|
||||||
|
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] started (tick=every minute)")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the cron scheduler.
|
||||||
|
func (s *ScheduledTestRunnerService) Stop() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stopOnce.Do(func() {
|
||||||
|
if s.cron != nil {
|
||||||
|
ctx := s.cron.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] cron stop timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScheduledTestRunnerService) runScheduled() {
|
||||||
|
// Delay 10s so execution lands at ~:10 of each minute instead of :00.
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
plans, err := s.planRepo.ListDue(ctx, now)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] ListDue error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(plans) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] found %d due plans", len(plans))
|
||||||
|
|
||||||
|
sem := make(chan struct{}, scheduledTestDefaultMaxWorkers)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for _, plan := range plans {
|
||||||
|
sem <- struct{}{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(p *ScheduledTestPlan) {
|
||||||
|
defer wg.Done()
|
||||||
|
defer func() { <-sem }()
|
||||||
|
s.runOnePlan(ctx, p)
|
||||||
|
}(plan)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *ScheduledTestPlan) {
|
||||||
|
result, err := s.accountTestSvc.RunTestBackground(ctx, plan.AccountID, plan.ModelID)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d RunTestBackground error: %v", plan.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.scheduledSvc.SaveResult(ctx, plan.ID, plan.MaxResults, result); err != nil {
|
||||||
|
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d SaveResult error: %v", plan.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d computeNextRun error: %v", plan.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.planRepo.UpdateAfterRun(ctx, plan.ID, time.Now(), nextRun); err != nil {
|
||||||
|
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d UpdateAfterRun error: %v", plan.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
94
backend/internal/service/scheduled_test_service.go
Normal file
94
backend/internal/service/scheduled_test_service.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/robfig/cron/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var scheduledTestCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||||
|
|
||||||
|
// ScheduledTestService provides CRUD operations for scheduled test plans and results.
|
||||||
|
type ScheduledTestService struct {
|
||||||
|
planRepo ScheduledTestPlanRepository
|
||||||
|
resultRepo ScheduledTestResultRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScheduledTestService creates a new ScheduledTestService.
|
||||||
|
func NewScheduledTestService(
|
||||||
|
planRepo ScheduledTestPlanRepository,
|
||||||
|
resultRepo ScheduledTestResultRepository,
|
||||||
|
) *ScheduledTestService {
|
||||||
|
return &ScheduledTestService{
|
||||||
|
planRepo: planRepo,
|
||||||
|
resultRepo: resultRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePlan validates the cron expression, computes next_run_at, and persists the plan.
|
||||||
|
func (s *ScheduledTestService) CreatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) {
|
||||||
|
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid cron expression: %w", err)
|
||||||
|
}
|
||||||
|
plan.NextRunAt = &nextRun
|
||||||
|
|
||||||
|
if plan.MaxResults <= 0 {
|
||||||
|
plan.MaxResults = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.planRepo.Create(ctx, plan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPlan retrieves a plan by ID.
|
||||||
|
func (s *ScheduledTestService) GetPlan(ctx context.Context, id int64) (*ScheduledTestPlan, error) {
|
||||||
|
return s.planRepo.GetByID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListPlansByAccount returns all plans for a given account.
|
||||||
|
func (s *ScheduledTestService) ListPlansByAccount(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error) {
|
||||||
|
return s.planRepo.ListByAccountID(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePlan validates cron and updates the plan.
|
||||||
|
func (s *ScheduledTestService) UpdatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) {
|
||||||
|
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid cron expression: %w", err)
|
||||||
|
}
|
||||||
|
plan.NextRunAt = &nextRun
|
||||||
|
|
||||||
|
return s.planRepo.Update(ctx, plan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePlan removes a plan and its results (via CASCADE).
|
||||||
|
func (s *ScheduledTestService) DeletePlan(ctx context.Context, id int64) error {
|
||||||
|
return s.planRepo.Delete(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListResults returns the most recent results for a plan.
|
||||||
|
func (s *ScheduledTestService) ListResults(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
return s.resultRepo.ListByPlanID(ctx, planID, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveResult inserts a result and prunes old entries beyond maxResults.
|
||||||
|
func (s *ScheduledTestService) SaveResult(ctx context.Context, planID int64, maxResults int, result *ScheduledTestResult) error {
|
||||||
|
result.PlanID = planID
|
||||||
|
if _, err := s.resultRepo.Create(ctx, result); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.resultRepo.PruneOldResults(ctx, planID, maxResults)
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeNextRun(cronExpr string, from time.Time) (time.Time, error) {
|
||||||
|
sched, err := scheduledTestCronParser.Parse(cronExpr)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, err
|
||||||
|
}
|
||||||
|
return sched.Next(from), nil
|
||||||
|
}
|
||||||
@@ -108,6 +108,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
keys := []string{
|
keys := []string{
|
||||||
SettingKeyRegistrationEnabled,
|
SettingKeyRegistrationEnabled,
|
||||||
SettingKeyEmailVerifyEnabled,
|
SettingKeyEmailVerifyEnabled,
|
||||||
|
SettingKeyRegistrationEmailSuffixWhitelist,
|
||||||
SettingKeyPromoCodeEnabled,
|
SettingKeyPromoCodeEnabled,
|
||||||
SettingKeyPasswordResetEnabled,
|
SettingKeyPasswordResetEnabled,
|
||||||
SettingKeyInvitationCodeEnabled,
|
SettingKeyInvitationCodeEnabled,
|
||||||
@@ -144,29 +145,33 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
// Password reset requires email verification to be enabled
|
// Password reset requires email verification to be enabled
|
||||||
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
||||||
passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true"
|
passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true"
|
||||||
|
registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist(
|
||||||
|
settings[SettingKeyRegistrationEmailSuffixWhitelist],
|
||||||
|
)
|
||||||
|
|
||||||
return &PublicSettings{
|
return &PublicSettings{
|
||||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||||
EmailVerifyEnabled: emailVerifyEnabled,
|
EmailVerifyEnabled: emailVerifyEnabled,
|
||||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist,
|
||||||
PasswordResetEnabled: passwordResetEnabled,
|
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
PasswordResetEnabled: passwordResetEnabled,
|
||||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||||
SiteLogo: settings[SettingKeySiteLogo],
|
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
SiteLogo: settings[SettingKeySiteLogo],
|
||||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||||
ContactInfo: settings[SettingKeyContactInfo],
|
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||||
DocURL: settings[SettingKeyDocURL],
|
ContactInfo: settings[SettingKeyContactInfo],
|
||||||
HomeContent: settings[SettingKeyHomeContent],
|
DocURL: settings[SettingKeyDocURL],
|
||||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
HomeContent: settings[SettingKeyHomeContent],
|
||||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||||
|
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,51 +201,53 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
|||||||
|
|
||||||
// Return a struct that matches the frontend's expected format
|
// Return a struct that matches the frontend's expected format
|
||||||
return &struct {
|
return &struct {
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
TotpEnabled bool `json:"totp_enabled"`
|
||||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||||
SiteName string `json:"site_name"`
|
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||||
SiteLogo string `json:"site_logo,omitempty"`
|
SiteName string `json:"site_name"`
|
||||||
SiteSubtitle string `json:"site_subtitle,omitempty"`
|
SiteLogo string `json:"site_logo,omitempty"`
|
||||||
APIBaseURL string `json:"api_base_url,omitempty"`
|
SiteSubtitle string `json:"site_subtitle,omitempty"`
|
||||||
ContactInfo string `json:"contact_info,omitempty"`
|
APIBaseURL string `json:"api_base_url,omitempty"`
|
||||||
DocURL string `json:"doc_url,omitempty"`
|
ContactInfo string `json:"contact_info,omitempty"`
|
||||||
HomeContent string `json:"home_content,omitempty"`
|
DocURL string `json:"doc_url,omitempty"`
|
||||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
HomeContent string `json:"home_content,omitempty"`
|
||||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
||||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||||
Version string `json:"version,omitempty"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
}{
|
}{
|
||||||
RegistrationEnabled: settings.RegistrationEnabled,
|
RegistrationEnabled: settings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TurnstileEnabled: settings.TurnstileEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
TurnstileEnabled: settings.TurnstileEnabled,
|
||||||
SiteName: settings.SiteName,
|
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||||
SiteLogo: settings.SiteLogo,
|
SiteName: settings.SiteName,
|
||||||
SiteSubtitle: settings.SiteSubtitle,
|
SiteLogo: settings.SiteLogo,
|
||||||
APIBaseURL: settings.APIBaseURL,
|
SiteSubtitle: settings.SiteSubtitle,
|
||||||
ContactInfo: settings.ContactInfo,
|
APIBaseURL: settings.APIBaseURL,
|
||||||
DocURL: settings.DocURL,
|
ContactInfo: settings.ContactInfo,
|
||||||
HomeContent: settings.HomeContent,
|
DocURL: settings.DocURL,
|
||||||
HideCcsImportButton: settings.HideCcsImportButton,
|
HomeContent: settings.HomeContent,
|
||||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
HideCcsImportButton: settings.HideCcsImportButton,
|
||||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||||
SoraClientEnabled: settings.SoraClientEnabled,
|
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||||
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
SoraClientEnabled: settings.SoraClientEnabled,
|
||||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
||||||
Version: s.version,
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
|
Version: s.version,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -356,12 +363,25 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
|||||||
if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
|
if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist)
|
||||||
|
if err != nil {
|
||||||
|
return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
|
||||||
|
}
|
||||||
|
if normalizedWhitelist == nil {
|
||||||
|
normalizedWhitelist = []string{}
|
||||||
|
}
|
||||||
|
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
|
||||||
|
|
||||||
updates := make(map[string]string)
|
updates := make(map[string]string)
|
||||||
|
|
||||||
// 注册设置
|
// 注册设置
|
||||||
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
|
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
|
||||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||||
|
registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal registration email suffix whitelist: %w", err)
|
||||||
|
}
|
||||||
|
updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
|
||||||
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
||||||
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
|
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
|
||||||
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
|
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
|
||||||
@@ -514,6 +534,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
|||||||
return value == "true"
|
return value == "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRegistrationEmailSuffixWhitelist returns normalized registration email suffix whitelist.
|
||||||
|
func (s *SettingService) GetRegistrationEmailSuffixWhitelist(ctx context.Context) []string {
|
||||||
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEmailSuffixWhitelist)
|
||||||
|
if err != nil {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
return ParseRegistrationEmailSuffixWhitelist(value)
|
||||||
|
}
|
||||||
|
|
||||||
// IsPromoCodeEnabled 检查是否启用优惠码功能
|
// IsPromoCodeEnabled 检查是否启用优惠码功能
|
||||||
func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
|
func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
|
||||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled)
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled)
|
||||||
@@ -617,20 +646,21 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
|
|
||||||
// 初始化默认设置
|
// 初始化默认设置
|
||||||
defaults := map[string]string{
|
defaults := map[string]string{
|
||||||
SettingKeyRegistrationEnabled: "true",
|
SettingKeyRegistrationEnabled: "true",
|
||||||
SettingKeyEmailVerifyEnabled: "false",
|
SettingKeyEmailVerifyEnabled: "false",
|
||||||
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||||
SettingKeySiteName: "Sub2API",
|
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
||||||
SettingKeySiteLogo: "",
|
SettingKeySiteName: "Sub2API",
|
||||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
SettingKeySiteLogo: "",
|
||||||
SettingKeyPurchaseSubscriptionURL: "",
|
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||||
SettingKeySoraClientEnabled: "false",
|
SettingKeyPurchaseSubscriptionURL: "",
|
||||||
SettingKeyCustomMenuItems: "[]",
|
SettingKeySoraClientEnabled: "false",
|
||||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
SettingKeyCustomMenuItems: "[]",
|
||||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||||
SettingKeyDefaultSubscriptions: "[]",
|
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||||
SettingKeySMTPPort: "587",
|
SettingKeyDefaultSubscriptions: "[]",
|
||||||
SettingKeySMTPUseTLS: "false",
|
SettingKeySMTPPort: "587",
|
||||||
|
SettingKeySMTPUseTLS: "false",
|
||||||
// Model fallback defaults
|
// Model fallback defaults
|
||||||
SettingKeyEnableModelFallback: "false",
|
SettingKeyEnableModelFallback: "false",
|
||||||
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
|
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
|
||||||
@@ -661,33 +691,34 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
|
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
|
||||||
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
||||||
result := &SystemSettings{
|
result := &SystemSettings{
|
||||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||||
EmailVerifyEnabled: emailVerifyEnabled,
|
EmailVerifyEnabled: emailVerifyEnabled,
|
||||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
RegistrationEmailSuffixWhitelist: ParseRegistrationEmailSuffixWhitelist(settings[SettingKeyRegistrationEmailSuffixWhitelist]),
|
||||||
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
|
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
|
||||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||||
SMTPHost: settings[SettingKeySMTPHost],
|
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
SMTPHost: settings[SettingKeySMTPHost],
|
||||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||||
SMTPFromName: settings[SettingKeySMTPFromName],
|
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||||
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
|
SMTPFromName: settings[SettingKeySMTPFromName],
|
||||||
SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "",
|
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
|
||||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "",
|
||||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||||
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
|
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
|
||||||
SiteLogo: settings[SettingKeySiteLogo],
|
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
SiteLogo: settings[SettingKeySiteLogo],
|
||||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||||
ContactInfo: settings[SettingKeyContactInfo],
|
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||||
DocURL: settings[SettingKeyDocURL],
|
ContactInfo: settings[SettingKeyContactInfo],
|
||||||
HomeContent: settings[SettingKeyHomeContent],
|
DocURL: settings[SettingKeyDocURL],
|
||||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
HomeContent: settings[SettingKeyHomeContent],
|
||||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||||
|
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析整数类型
|
// 解析整数类型
|
||||||
|
|||||||
64
backend/internal/service/setting_service_public_test.go
Normal file
64
backend/internal/service/setting_service_public_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type settingPublicRepoStub struct {
|
||||||
|
values map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingPublicRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||||
|
panic("unexpected Get call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||||
|
panic("unexpected GetValue call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingPublicRepoStub) Set(ctx context.Context, key, value string) error {
|
||||||
|
panic("unexpected Set call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||||
|
out := make(map[string]string, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
if value, ok := s.values[key]; ok {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||||
|
panic("unexpected SetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||||
|
panic("unexpected GetAll call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingPublicRepoStub) Delete(ctx context.Context, key string) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelist(t *testing.T) {
|
||||||
|
repo := &settingPublicRepoStub{
|
||||||
|
values: map[string]string{
|
||||||
|
SettingKeyRegistrationEnabled: "true",
|
||||||
|
SettingKeyEmailVerifyEnabled: "true",
|
||||||
|
SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","@invalid_domain",""]`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewSettingService(repo, &config.Config{})
|
||||||
|
|
||||||
|
settings, err := svc.GetPublicSettings(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist)
|
||||||
|
}
|
||||||
@@ -172,6 +172,28 @@ func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGrou
|
|||||||
require.Nil(t, repo.updates)
|
require.Nil(t, repo.updates)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Normalized(t *testing.T) {
|
||||||
|
repo := &settingUpdateRepoStub{}
|
||||||
|
svc := NewSettingService(repo, &config.Config{})
|
||||||
|
|
||||||
|
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||||
|
RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar "},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, `["@example.com","@foo.bar"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Invalid(t *testing.T) {
|
||||||
|
repo := &settingUpdateRepoStub{}
|
||||||
|
svc := NewSettingService(repo, &config.Config{})
|
||||||
|
|
||||||
|
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||||
|
RegistrationEmailSuffixWhitelist: []string{"@invalid_domain"},
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, "INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", infraerrors.Reason(err))
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
|
func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
|
||||||
got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`)
|
got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`)
|
||||||
require.Equal(t, []DefaultSubscriptionSetting{
|
require.Equal(t, []DefaultSubscriptionSetting{
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
type SystemSettings struct {
|
type SystemSettings struct {
|
||||||
RegistrationEnabled bool
|
RegistrationEnabled bool
|
||||||
EmailVerifyEnabled bool
|
EmailVerifyEnabled bool
|
||||||
PromoCodeEnabled bool
|
RegistrationEmailSuffixWhitelist []string
|
||||||
PasswordResetEnabled bool
|
PromoCodeEnabled bool
|
||||||
InvitationCodeEnabled bool
|
PasswordResetEnabled bool
|
||||||
TotpEnabled bool // TOTP 双因素认证
|
InvitationCodeEnabled bool
|
||||||
|
TotpEnabled bool // TOTP 双因素认证
|
||||||
|
|
||||||
SMTPHost string
|
SMTPHost string
|
||||||
SMTPPort int
|
SMTPPort int
|
||||||
@@ -76,22 +77,23 @@ type DefaultSubscriptionSetting struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PublicSettings struct {
|
type PublicSettings struct {
|
||||||
RegistrationEnabled bool
|
RegistrationEnabled bool
|
||||||
EmailVerifyEnabled bool
|
EmailVerifyEnabled bool
|
||||||
PromoCodeEnabled bool
|
RegistrationEmailSuffixWhitelist []string
|
||||||
PasswordResetEnabled bool
|
PromoCodeEnabled bool
|
||||||
InvitationCodeEnabled bool
|
PasswordResetEnabled bool
|
||||||
TotpEnabled bool // TOTP 双因素认证
|
InvitationCodeEnabled bool
|
||||||
TurnstileEnabled bool
|
TotpEnabled bool // TOTP 双因素认证
|
||||||
TurnstileSiteKey string
|
TurnstileEnabled bool
|
||||||
SiteName string
|
TurnstileSiteKey string
|
||||||
SiteLogo string
|
SiteName string
|
||||||
SiteSubtitle string
|
SiteLogo string
|
||||||
APIBaseURL string
|
SiteSubtitle string
|
||||||
ContactInfo string
|
APIBaseURL string
|
||||||
DocURL string
|
ContactInfo string
|
||||||
HomeContent string
|
DocURL string
|
||||||
HideCcsImportButton bool
|
HomeContent string
|
||||||
|
HideCcsImportButton bool
|
||||||
|
|
||||||
PurchaseSubscriptionEnabled bool
|
PurchaseSubscriptionEnabled bool
|
||||||
PurchaseSubscriptionURL string
|
PurchaseSubscriptionURL string
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ type UserListFilters struct {
|
|||||||
Role string // User role filter
|
Role string // User role filter
|
||||||
Search string // Search in email, username
|
Search string // Search in email, username
|
||||||
Attributes map[int64]string // Custom attribute filters: attributeID -> value
|
Attributes map[int64]string // Custom attribute filters: attributeID -> value
|
||||||
|
// IncludeSubscriptions controls whether ListWithFilters should load active subscriptions.
|
||||||
|
// For large datasets this can be expensive; admin list pages should enable it on demand.
|
||||||
|
// nil means not specified (default: load subscriptions for backward compatibility).
|
||||||
|
IncludeSubscriptions *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserRepository interface {
|
type UserRepository interface {
|
||||||
|
|||||||
@@ -274,6 +274,26 @@ func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Co
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideScheduledTestService creates ScheduledTestService.
|
||||||
|
func ProvideScheduledTestService(
|
||||||
|
planRepo ScheduledTestPlanRepository,
|
||||||
|
resultRepo ScheduledTestResultRepository,
|
||||||
|
) *ScheduledTestService {
|
||||||
|
return NewScheduledTestService(planRepo, resultRepo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideScheduledTestRunnerService creates and starts ScheduledTestRunnerService.
|
||||||
|
func ProvideScheduledTestRunnerService(
|
||||||
|
planRepo ScheduledTestPlanRepository,
|
||||||
|
scheduledSvc *ScheduledTestService,
|
||||||
|
accountTestSvc *AccountTestService,
|
||||||
|
cfg *config.Config,
|
||||||
|
) *ScheduledTestRunnerService {
|
||||||
|
svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, cfg)
|
||||||
|
svc.Start()
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
|
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
|
||||||
func ProvideOpsScheduledReportService(
|
func ProvideOpsScheduledReportService(
|
||||||
opsService *OpsService,
|
opsService *OpsService,
|
||||||
@@ -380,4 +400,6 @@ var ProviderSet = wire.NewSet(
|
|||||||
ProvideIdempotencyCoordinator,
|
ProvideIdempotencyCoordinator,
|
||||||
ProvideSystemOperationLockService,
|
ProvideSystemOperationLockService,
|
||||||
ProvideIdempotencyCleanupService,
|
ProvideIdempotencyCleanupService,
|
||||||
|
ProvideScheduledTestService,
|
||||||
|
ProvideScheduledTestRunnerService,
|
||||||
)
|
)
|
||||||
|
|||||||
33
backend/migrations/065_add_search_trgm_indexes.sql
Normal file
33
backend/migrations/065_add_search_trgm_indexes.sql
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
-- Improve admin fuzzy-search performance on large datasets.
|
||||||
|
-- Best effort:
|
||||||
|
-- 1) try enabling pg_trgm
|
||||||
|
-- 2) only create trigram indexes when extension is available
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS pg_trgm;
|
||||||
|
EXCEPTION
|
||||||
|
WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pg_trgm extension not created: %', SQLERRM;
|
||||||
|
END;
|
||||||
|
|
||||||
|
IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm') THEN
|
||||||
|
EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_email_trgm
|
||||||
|
ON users USING gin (email gin_trgm_ops)';
|
||||||
|
EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_username_trgm
|
||||||
|
ON users USING gin (username gin_trgm_ops)';
|
||||||
|
EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_notes_trgm
|
||||||
|
ON users USING gin (notes gin_trgm_ops)';
|
||||||
|
|
||||||
|
EXECUTE 'CREATE INDEX IF NOT EXISTS idx_accounts_name_trgm
|
||||||
|
ON accounts USING gin (name gin_trgm_ops)';
|
||||||
|
|
||||||
|
EXECUTE 'CREATE INDEX IF NOT EXISTS idx_api_keys_key_trgm
|
||||||
|
ON api_keys USING gin ("key" gin_trgm_ops)';
|
||||||
|
EXECUTE 'CREATE INDEX IF NOT EXISTS idx_api_keys_name_trgm
|
||||||
|
ON api_keys USING gin (name gin_trgm_ops)';
|
||||||
|
ELSE
|
||||||
|
RAISE NOTICE 'skip trigram indexes because pg_trgm is unavailable';
|
||||||
|
END IF;
|
||||||
|
END
|
||||||
|
$$;
|
||||||
30
backend/migrations/066_add_scheduled_test_tables.sql
Normal file
30
backend/migrations/066_add_scheduled_test_tables.sql
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
-- 066_add_scheduled_test_tables.sql
|
||||||
|
-- Scheduled account test plans and results
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS scheduled_test_plans (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||||
|
model_id VARCHAR(100) NOT NULL DEFAULT '',
|
||||||
|
cron_expression VARCHAR(100) NOT NULL DEFAULT '*/30 * * * *',
|
||||||
|
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||||
|
max_results INT NOT NULL DEFAULT 50,
|
||||||
|
last_run_at TIMESTAMPTZ,
|
||||||
|
next_run_at TIMESTAMPTZ,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_stp_account_id ON scheduled_test_plans(account_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_stp_enabled_next_run ON scheduled_test_plans(enabled, next_run_at) WHERE enabled = true;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS scheduled_test_results (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
plan_id BIGINT NOT NULL REFERENCES scheduled_test_plans(id) ON DELETE CASCADE,
|
||||||
|
status VARCHAR(20) NOT NULL DEFAULT 'success',
|
||||||
|
response_text TEXT NOT NULL DEFAULT '',
|
||||||
|
error_message TEXT NOT NULL DEFAULT '',
|
||||||
|
latency_ms BIGINT NOT NULL DEFAULT 0,
|
||||||
|
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
finished_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_str_plan_created ON scheduled_test_results(plan_id, created_at DESC);
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
# 本地构建镜像的快速脚本,避免在命令行反复输入构建参数。
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
|
||||||
|
|
||||||
docker build -t sub2api:latest \
|
|
||||||
--build-arg GOPROXY=https://goproxy.cn,direct \
|
|
||||||
--build-arg GOSUMDB=sum.golang.google.cn \
|
|
||||||
-f "${SCRIPT_DIR}/Dockerfile" \
|
|
||||||
"${SCRIPT_DIR}"
|
|
||||||
@@ -112,7 +112,7 @@ POSTGRES_DB=sub2api
|
|||||||
DATABASE_PORT=5432
|
DATABASE_PORT=5432
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# PostgreSQL 服务端参数(可选;主要用于 deploy/docker-compose-aicodex.yml)
|
# PostgreSQL 服务端参数(可选)
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# POSTGRES_MAX_CONNECTIONS:PostgreSQL 服务端允许的最大连接数。
|
# POSTGRES_MAX_CONNECTIONS:PostgreSQL 服务端允许的最大连接数。
|
||||||
# 必须 >=(所有 Sub2API 实例的 DATABASE_MAX_OPEN_CONNS 之和)+ 预留余量(例如 20%)。
|
# 必须 >=(所有 Sub2API 实例的 DATABASE_MAX_OPEN_CONNS 之和)+ 预留余量(例如 20%)。
|
||||||
@@ -163,7 +163,7 @@ REDIS_PORT=6379
|
|||||||
# Leave empty for no password (default for local development)
|
# Leave empty for no password (default for local development)
|
||||||
REDIS_PASSWORD=
|
REDIS_PASSWORD=
|
||||||
REDIS_DB=0
|
REDIS_DB=0
|
||||||
# Redis 服务端最大客户端连接数(可选;主要用于 deploy/docker-compose-aicodex.yml)
|
# Redis 服务端最大客户端连接数(可选)
|
||||||
REDIS_MAXCLIENTS=50000
|
REDIS_MAXCLIENTS=50000
|
||||||
# Redis 连接池大小(默认 1024)
|
# Redis 连接池大小(默认 1024)
|
||||||
REDIS_POOL_SIZE=4096
|
REDIS_POOL_SIZE=4096
|
||||||
|
|||||||
@@ -209,8 +209,9 @@ gateway:
|
|||||||
openai_ws:
|
openai_ws:
|
||||||
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
|
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
|
||||||
mode_router_v2_enabled: false
|
mode_router_v2_enabled: false
|
||||||
# ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
|
# ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效)
|
||||||
ingress_mode_default: shared
|
# 兼容旧值:shared/dedicated 会按 ctx_pool 处理。
|
||||||
|
ingress_mode_default: ctx_pool
|
||||||
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
|
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
|
||||||
enabled: true
|
enabled: true
|
||||||
# 按账号类型细分开关
|
# 按账号类型细分开关
|
||||||
|
|||||||
@@ -1,212 +0,0 @@
|
|||||||
# =============================================================================
|
|
||||||
# Sub2API Docker Compose Test Configuration (Local Build)
|
|
||||||
# =============================================================================
|
|
||||||
# Quick Start:
|
|
||||||
# 1. Copy .env.example to .env and configure
|
|
||||||
# 2. docker-compose -f docker-compose-test.yml up -d --build
|
|
||||||
# 3. Check logs: docker-compose -f docker-compose-test.yml logs -f sub2api
|
|
||||||
# 4. Access: http://localhost:8080
|
|
||||||
#
|
|
||||||
# This configuration builds the image from source (Dockerfile in project root).
|
|
||||||
# All configuration is done via environment variables.
|
|
||||||
# No Setup Wizard needed - the system auto-initializes on first run.
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
services:
|
|
||||||
# ===========================================================================
|
|
||||||
# Sub2API Application
|
|
||||||
# ===========================================================================
|
|
||||||
sub2api:
|
|
||||||
image: sub2api:latest
|
|
||||||
build:
|
|
||||||
context: ..
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
container_name: sub2api
|
|
||||||
restart: unless-stopped
|
|
||||||
ulimits:
|
|
||||||
nofile:
|
|
||||||
soft: 100000
|
|
||||||
hard: 100000
|
|
||||||
ports:
|
|
||||||
- "${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080"
|
|
||||||
volumes:
|
|
||||||
# Data persistence (config.yaml will be auto-generated here)
|
|
||||||
- sub2api_data:/app/data
|
|
||||||
# Mount custom config.yaml (optional, overrides auto-generated config)
|
|
||||||
# - ./config.yaml:/app/data/config.yaml:ro
|
|
||||||
environment:
|
|
||||||
# =======================================================================
|
|
||||||
# Auto Setup (REQUIRED for Docker deployment)
|
|
||||||
# =======================================================================
|
|
||||||
- AUTO_SETUP=true
|
|
||||||
|
|
||||||
# =======================================================================
|
|
||||||
# Server Configuration
|
|
||||||
# =======================================================================
|
|
||||||
- SERVER_HOST=0.0.0.0
|
|
||||||
- SERVER_PORT=8080
|
|
||||||
- SERVER_MODE=${SERVER_MODE:-release}
|
|
||||||
- RUN_MODE=${RUN_MODE:-standard}
|
|
||||||
|
|
||||||
# =======================================================================
|
|
||||||
# Database Configuration (PostgreSQL)
|
|
||||||
# =======================================================================
|
|
||||||
- DATABASE_HOST=postgres
|
|
||||||
- DATABASE_PORT=5432
|
|
||||||
- DATABASE_USER=${POSTGRES_USER:-sub2api}
|
|
||||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
|
||||||
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
|
||||||
- DATABASE_SSLMODE=disable
|
|
||||||
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
|
|
||||||
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
|
|
||||||
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
|
|
||||||
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
|
|
||||||
|
|
||||||
# =======================================================================
|
|
||||||
# Redis Configuration
|
|
||||||
# =======================================================================
|
|
||||||
- REDIS_HOST=redis
|
|
||||||
- REDIS_PORT=6379
|
|
||||||
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
|
||||||
- REDIS_DB=${REDIS_DB:-0}
|
|
||||||
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
|
|
||||||
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
|
|
||||||
|
|
||||||
# =======================================================================
|
|
||||||
# Admin Account (auto-created on first run)
|
|
||||||
# =======================================================================
|
|
||||||
- ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
|
|
||||||
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
|
|
||||||
|
|
||||||
# =======================================================================
|
|
||||||
# JWT Configuration
|
|
||||||
# =======================================================================
|
|
||||||
# Leave empty to auto-generate (recommended)
|
|
||||||
- JWT_SECRET=${JWT_SECRET:-}
|
|
||||||
- JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24}
|
|
||||||
|
|
||||||
# =======================================================================
|
|
||||||
# Timezone Configuration
|
|
||||||
# This affects ALL time operations in the application:
|
|
||||||
# - Database timestamps
|
|
||||||
# - Usage statistics "today" boundary
|
|
||||||
# - Subscription expiry times
|
|
||||||
# - Log timestamps
|
|
||||||
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
|
|
||||||
# =======================================================================
|
|
||||||
- TZ=${TZ:-Asia/Shanghai}
|
|
||||||
|
|
||||||
# =======================================================================
|
|
||||||
# Gemini OAuth Configuration (for Gemini accounts)
|
|
||||||
# =======================================================================
|
|
||||||
- GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-}
|
|
||||||
- GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-}
|
|
||||||
- GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-}
|
|
||||||
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
|
|
||||||
|
|
||||||
# Built-in OAuth client secrets (optional)
|
|
||||||
# SECURITY: This repo does not embed third-party client_secret.
|
|
||||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
|
||||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
|
||||||
|
|
||||||
# =======================================================================
|
|
||||||
# Security Configuration (URL Allowlist)
|
|
||||||
# =======================================================================
|
|
||||||
# Allow private IP addresses for CRS sync (for internal deployments)
|
|
||||||
- SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-true}
|
|
||||||
depends_on:
|
|
||||||
postgres:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
networks:
|
|
||||||
- sub2api-network
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
|
||||||
interval: 30s
|
|
||||||
timeout: 10s
|
|
||||||
retries: 3
|
|
||||||
start_period: 30s
|
|
||||||
|
|
||||||
# ===========================================================================
|
|
||||||
# PostgreSQL Database
|
|
||||||
# ===========================================================================
|
|
||||||
postgres:
|
|
||||||
image: postgres:18-alpine
|
|
||||||
container_name: sub2api-postgres
|
|
||||||
restart: unless-stopped
|
|
||||||
ulimits:
|
|
||||||
nofile:
|
|
||||||
soft: 100000
|
|
||||||
hard: 100000
|
|
||||||
volumes:
|
|
||||||
- postgres_data:/var/lib/postgresql/data
|
|
||||||
environment:
|
|
||||||
# postgres:18-alpine 默认 PGDATA=/var/lib/postgresql/18/docker(位于镜像声明的匿名卷 /var/lib/postgresql 内)。
|
|
||||||
# 若不显式设置 PGDATA,则即使挂载了 postgres_data 到 /var/lib/postgresql/data,数据也不会落盘到该命名卷,
|
|
||||||
# docker compose down/up 后会触发 initdb 重新初始化,导致用户/密码等数据丢失。
|
|
||||||
- PGDATA=/var/lib/postgresql/data
|
|
||||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
|
||||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
|
||||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
|
||||||
- TZ=${TZ:-Asia/Shanghai}
|
|
||||||
networks:
|
|
||||||
- sub2api-network
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"]
|
|
||||||
interval: 10s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
start_period: 10s
|
|
||||||
# 注意:不暴露端口到宿主机,应用通过内部网络连接
|
|
||||||
# 如需调试,可临时添加:ports: ["127.0.0.1:5433:5432"]
|
|
||||||
|
|
||||||
# ===========================================================================
|
|
||||||
# Redis Cache
|
|
||||||
# ===========================================================================
|
|
||||||
redis:
|
|
||||||
image: redis:8-alpine
|
|
||||||
container_name: sub2api-redis
|
|
||||||
restart: unless-stopped
|
|
||||||
ulimits:
|
|
||||||
nofile:
|
|
||||||
soft: 100000
|
|
||||||
hard: 100000
|
|
||||||
volumes:
|
|
||||||
- redis_data:/data
|
|
||||||
command: >
|
|
||||||
redis-server
|
|
||||||
--save 60 1
|
|
||||||
--appendonly yes
|
|
||||||
--appendfsync everysec
|
|
||||||
${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}}
|
|
||||||
environment:
|
|
||||||
- TZ=${TZ:-Asia/Shanghai}
|
|
||||||
# REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag)
|
|
||||||
- REDISCLI_AUTH=${REDIS_PASSWORD:-}
|
|
||||||
networks:
|
|
||||||
- sub2api-network
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
|
||||||
interval: 10s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
start_period: 5s
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Volumes
|
|
||||||
# =============================================================================
|
|
||||||
volumes:
|
|
||||||
sub2api_data:
|
|
||||||
driver: local
|
|
||||||
postgres_data:
|
|
||||||
driver: local
|
|
||||||
redis_data:
|
|
||||||
driver: local
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Networks
|
|
||||||
# =============================================================================
|
|
||||||
networks:
|
|
||||||
sub2api-network:
|
|
||||||
driver: bridge
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user