Compare commits

..

30 Commits

Author SHA1 Message Date
shaw
bab4bb9904 chore: 更新openai、claude使用秘钥教程部分 2026-03-05 18:58:10 +08:00
shaw
33bae6f49b fix: Cache Token拆分为缓存创建和缓存读取 2026-03-05 18:32:17 +08:00
Wesley Liddick
32d619a56b Merge pull request #780 from mt21625457/feat/codex-remote-compact-outcome-logging
feat(openai-handler): support codex remote compact outcome logging
2026-03-05 16:59:02 +08:00
Wesley Liddick
642432cf2a Merge pull request #777 from guoyongchang/feature-schedule-test-support
feat: 支持基于 crontab 的定时账号测试
2026-03-05 16:57:23 +08:00
程序猿MT
61e9598b08 fix(lint): remove redundant context type in compact outcome logger 2026-03-05 16:51:46 +08:00
guoyongchang
d4e34c7514 fix: 修复空结果导致定时测试模态框崩溃的问题
后端返回 null (Go nil slice) 时前端访问 .length 抛出 TypeError,
在 API 层对 listByAccount 和 listResults 加 ?? [] 兜底。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:47:01 +08:00
程序猿MT
bfe7a5e452 test(openai-handler): add codex remote compact outcome coverage 2026-03-05 16:46:14 +08:00
程序猿MT
77d916ffec feat(openai-handler): support codex remote compact outcome logging 2026-03-05 16:46:12 +08:00
guoyongchang
831abf7977 refactor: 移除冗余中间类型和不必要代码
- 移除 ScheduledTestOutcome 中间类型,RunTestBackground 直接返回 *ScheduledTestResult
- 简化 SaveResult 直接接受 *ScheduledTestResult
- 移除 handler 中不必要的 nil 检查
- 移除前端 ScheduledTestsPanel 中多余的 String() 转换

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:37:07 +08:00
guoyongchang
817a491087 simplify: 移除 leader lock,单实例无需分布式锁
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:31:27 +08:00
guoyongchang
9a8dacc514 fix: 修复 golangci-lint depguard 和 gofmt 错误
将 redis leader lock 逻辑从 service 层抽取为 LeaderLocker 接口,
实现移至 repository 层,消除 service 层对 redis 的直接依赖。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:28:48 +08:00
guoyongchang
8adf80d98b fix: wire_gen_test 补充 scheduledTestRunner 参数
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:23:41 +08:00
guoyongchang
62686a6213 revert: 还原 docker-compose.local.yml 的本地测试改动
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:17:33 +08:00
guoyongchang
3a089242f8 feat: 支持基于 crontab 的定时账号测试
每个测试计划绑定一个账号和一个模型,按 cron 表达式定期执行测试,
保存历史结果并在前端账号管理页面中提供完整的增删改查和结果查看功能。

主要变更:
- 新增 scheduled_test_plans / scheduled_test_results 两张表及迁移
- 后端 service 层:CRUD 服务 + 后台 cron runner(每分钟扫描到期计划并发执行)
- RunTestBackground 方法通过 httptest 在内存中执行账号测试并解析 SSE 输出
- Redis leader lock + pg_try_advisory_lock 双重保障多实例部署只执行一次
- REST API:5 个管理端点(计划 CRUD + 结果查询)
- 前端 ScheduledTestsPanel 组件:计划管理、启用开关、内联编辑、结果展开查看
- 中英文 i18n 支持

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 16:06:05 +08:00
shaw
9d70c38504 fix: 修复claude apikey账号请求时未携带beta=true 查询参数的bug 2026-03-05 15:01:04 +08:00
shaw
aeb464f3ca feat: 模型映射应用 /v1/messages/count_tokens端点 2026-03-05 14:49:28 +08:00
Wesley Liddick
7076717b20 Merge pull request #772 from mt21625457/aicodex2api-main
feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode
2026-03-05 13:46:02 +08:00
程序猿MT
c0a4fcea0a Delete docker-compose-aicodex.yml
删除测试 docker compose文件
2026-03-05 13:44:07 +08:00
程序猿MT
aa2b195c86 Delete Caddyfile.dmit
删除测试caddy 配置文件
2026-03-05 13:43:25 +08:00
yangjianbo
1d0872e7ca feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode
新增 OpenAI WebSocket v2 passthrough relay 数据面与服务适配层,
支持按账号 ws mode 在 ctx_pool 与 passthrough 间路由。

同步调整前端 OpenAI ws mode 选项为 off/ctx_pool/passthrough,
并补充 i18n 文案与对应单测。

新增 Caddyfile.dmit 与 docker-compose-aicodex.yml 部署配置,
用于宿主机场景下的反向代理与服务编排。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 11:50:58 +08:00
shaw
33988637b5 fix: SMTP测试连接和发送测试邮件返回具体错误信息而非internal error 2026-03-05 10:54:41 +08:00
shaw
d4f6ad7225 feat: 新增apikey的usage查询页面 2026-03-05 10:45:51 +08:00
shaw
078fefed03 fix: 修复账号管理页面容量列显示为0的bug 2026-03-05 09:48:00 +08:00
Wesley Liddick
5b10af85b4 Merge pull request #762 from touwaeriol/fix/dark-theme-open-in-new-tab
fix: add dark theme support for "open in new tab" FAB button
2026-03-05 08:56:28 +08:00
Wesley Liddick
4caf95e5dd Merge pull request #767 from litianc/fix/rewrite-userid-regex-match-account-uuid
fix: extend RewriteUserID regex to match user_id containing account_uuid
2026-03-05 08:56:03 +08:00
litianc
8e1bcf53bb fix: extend RewriteUserID regex to match user_id containing account_uuid
The existing regex only matched the old format where account_uuid is
empty (account__session_). Real Claude Code clients and newer sub2api
generated user_ids use account_{uuid}_session_ which was silently
skipped, causing the original metadata.user_id to leak to upstream
when User-Agent is rewritten by an intermediate gateway.

Closes #766

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-04 23:13:17 +08:00
erio
064f9be7e4 fix: add dark theme support for "open in new tab" FAB button
The backdrop-blur background on the iframe "open in new tab" floating
button was hardcoded to bg-white/80, making it look broken in dark
theme. Added dark:bg-dark-800/80 variant for both PurchaseSubscription
and CustomPage views.
2026-03-04 21:40:40 +08:00
Wesley Liddick
adcfb44cb7 Merge pull request #761 from james-6-23/main
feat: 修复 v0.1.89 OAuth 401 永久锁死账号问题,改用临时不可调度实现自动恢复;增强二次 401 自动升级为错误状态,添加 DB   回退确保生效;管理后台新增临时不可调度状态筛选
2026-03-04 21:11:24 +08:00
kyx236
3d79773ba2 Merge branch 'main' of https://github.com/james-6-23/sub2api 2026-03-04 20:25:39 +08:00
kyx236
6aa8cbbf20 feat: 二次 401 直接升级为错误状态,添加 DB 回退确保生效
账号首次 401 仅临时不可调度,给予 token 刷新窗口;若恢复后再次 401
说明凭证确实失效,直接升级为错误状态以避免反复无效调度。

- 缓存中 reason 为空时从 DB 回退读取,防止升级判断失效
- ClearError 同时清除临时不可调度状态,管理员恢复后重新给予一次机会
- 管理后台账号列表添加"临时不可调度"状态筛选
- 补充 DB 回退场景单元测试
2026-03-04 20:25:15 +08:00
70 changed files with 6418 additions and 219 deletions

View File

@@ -86,6 +86,7 @@ func provideCleanup(
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -216,6 +217,12 @@ func provideCleanup(
}
return nil
}},
{"ScheduledTestRunnerService", func() error {
if scheduledTestRunner != nil {
scheduledTestRunner.Stop()
}
return nil
}},
}
infraSteps := []cleanupStep{

View File

@@ -195,7 +195,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -225,7 +229,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -273,6 +278,7 @@ func provideCleanup(
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -402,6 +408,12 @@ func provideCleanup(
}
return nil
}},
{"ScheduledTestRunnerService", func() error {
if scheduledTestRunner != nil {
scheduledTestRunner.Stop()
}
return nil
}},
}
infraSteps := []cleanupStep{

View File

@@ -74,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
geminiOAuthSvc,
antigravityOAuthSvc,
nil, // openAIGateway
nil, // scheduledTestRunner
)
require.NotPanics(t, func() {

View File

@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
type GatewayOpenAIWSConfig struct {
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false关闭时保持 legacy 行为)
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
// IngressModeDefault: ingress 默认模式off/shared/dedicated
// IngressModeDefault: ingress 默认模式off/ctx_pool/passthrough
IngressModeDefault string `mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true
Enabled bool `mapstructure:"enabled"`
@@ -1335,7 +1335,7 @@ func setDefaults() {
// OpenAI Responses WebSocket默认开启可通过 force_http 紧急回滚)
viper.SetDefault("gateway.openai_ws.enabled", true)
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
viper.SetDefault("gateway.openai_ws.force_http", false)
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
switch mode {
case "off", "shared", "dedicated":
case "off", "ctx_pool", "passthrough":
case "shared", "dedicated":
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
default:
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
}
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {

View File

@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
}
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
}
}
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
},
{
name: "ingress_mode_default 必须为 off|shared|dedicated",
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
wantErr: "gateway.openai_ws.ingress_mode_default",
},

View File

@@ -240,77 +240,77 @@ func (h *AccountHandler) List(c *gin.Context) {
var windowCosts map[int64]float64
var activeSessions map[int64]int
var rpmCounts map[int64]int
if !lite {
// Get current concurrency counts for all accounts
if h.concurrencyService != nil {
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
concurrencyCounts = cc
// 始终获取并发数Redis ZCARD极低开销
if h.concurrencyService != nil {
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
concurrencyCounts = cc
}
}
// 识别需要查询窗口费用、会话数和 RPM 的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
rpmAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
if acc.GetWindowCostLimit() > 0 {
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
if acc.GetBaseRPM() > 0 {
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
}
}
// 识别需要查询窗口费用、会话数和 RPM 的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
rpmAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
}
// 始终获取 RPM 计数Redis GET极低开销
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
if rpmCounts == nil {
rpmCounts = make(map[int64]int)
}
}
// 始终获取活跃会话数Redis ZCARD低开销
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
}
// 仅非 lite 模式获取窗口费用PostgreSQL 聚合查询,高开销)
if !lite && len(windowCostAccountIDs) > 0 {
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
if acc.GetWindowCostLimit() > 0 {
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
if acc.GetBaseRPM() > 0 {
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
}
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
}
// 获取 RPM 计数(批量查询
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
if rpmCounts == nil {
rpmCounts = make(map[int64]int)
}
}
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
}
// 获取窗口费用(并行查询)
if len(windowCostAccountIDs) > 0 {
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
}
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
}
return nil // 不返回错误,允许部分失败
})
}
_ = g.Wait()
return nil // 不返回错误,允许部分失败
})
}
_ = g.Wait()
}
// Build response with concurrency info

View File

@@ -0,0 +1,155 @@
package admin
import (
"net/http"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ScheduledTestHandler handles admin scheduled-test-plan management.
type ScheduledTestHandler struct {
scheduledTestSvc *service.ScheduledTestService
}
// NewScheduledTestHandler creates a new ScheduledTestHandler.
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
}
type createScheduledTestPlanRequest struct {
AccountID int64 `json:"account_id" binding:"required"`
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression" binding:"required"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
}
type updateScheduledTestPlanRequest struct {
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
}
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid account id")
return
}
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, plans)
}
// Create POST /admin/scheduled-test-plans
func (h *ScheduledTestHandler) Create(c *gin.Context) {
var req createScheduledTestPlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
plan := &service.ScheduledTestPlan{
AccountID: req.AccountID,
ModelID: req.ModelID,
CronExpression: req.CronExpression,
Enabled: true,
MaxResults: req.MaxResults,
}
if req.Enabled != nil {
plan.Enabled = *req.Enabled
}
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
if err != nil {
response.BadRequest(c, err.Error())
return
}
c.JSON(http.StatusOK, created)
}
// Update PUT /admin/scheduled-test-plans/:id
func (h *ScheduledTestHandler) Update(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
if err != nil {
response.NotFound(c, "plan not found")
return
}
var req updateScheduledTestPlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if req.ModelID != "" {
existing.ModelID = req.ModelID
}
if req.CronExpression != "" {
existing.CronExpression = req.CronExpression
}
if req.Enabled != nil {
existing.Enabled = *req.Enabled
}
if req.MaxResults > 0 {
existing.MaxResults = req.MaxResults
}
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
if err != nil {
response.BadRequest(c, err.Error())
return
}
c.JSON(http.StatusOK, updated)
}
// Delete DELETE /admin/scheduled-test-plans/:id
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
}
// ListResults GET /admin/scheduled-test-plans/:id/results
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
limit := 50
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
limit = l
}
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
if err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, results)
}

View File

@@ -819,7 +819,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
err := h.emailService.TestSMTPConnectionWithConfig(config)
if err != nil {
response.ErrorFrom(c, err)
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
return
}
@@ -905,7 +905,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
`
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
response.ErrorFrom(c, err)
response.BadRequest(c, "Failed to send test email: "+err.Error())
return
}

View File

@@ -27,6 +27,7 @@ type AdminHandlers struct {
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
APIKey *admin.AdminAPIKeyHandler
ScheduledTest *admin.ScheduledTestHandler
}
// Handlers contains all HTTP handlers

View File

@@ -0,0 +1,192 @@
package handler
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
var handlerStructuredLogCaptureMu sync.Mutex
type handlerInMemoryLogSink struct {
mu sync.Mutex
events []*logger.LogEvent
}
func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
if event == nil {
return
}
cloned := *event
if event.Fields != nil {
cloned.Fields = make(map[string]any, len(event.Fields))
for k, v := range event.Fields {
cloned.Fields[k] = v
}
}
s.mu.Lock()
s.events = append(s.events, &cloned)
s.mu.Unlock()
}
func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
s.mu.Lock()
defer s.mu.Unlock()
wantLevel := strings.ToLower(strings.TrimSpace(level))
for _, ev := range s.events {
if ev == nil {
continue
}
if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel {
return true
}
}
return false
}
func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool {
s.mu.Lock()
defer s.mu.Unlock()
for _, ev := range s.events {
if ev == nil || ev.Fields == nil {
continue
}
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
return true
}
}
return false
}
func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) {
t.Helper()
handlerStructuredLogCaptureMu.Lock()
err := logger.Init(logger.InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: logger.SamplingOptions{Enabled: false},
})
require.NoError(t, err)
sink := &handlerInMemoryLogSink{}
logger.SetSink(sink)
return sink, func() {
logger.SetSink(nil)
handlerStructuredLogCaptureMu.Unlock()
}
}
func TestIsOpenAIRemoteCompactPath(t *testing.T) {
require.False(t, isOpenAIRemoteCompactPath(nil))
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
require.True(t, isOpenAIRemoteCompactPath(c))
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil)
require.True(t, isOpenAIRemoteCompactPath(c))
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
require.False(t, isOpenAIRemoteCompactPath(c))
}
func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Set(opsModelKey, "gpt-5.3-codex")
c.Set(opsAccountIDKey, int64(123))
c.Header("x-request-id", "rid-compact-ok")
c.Status(http.StatusOK)
h := &OpenAIGatewayHandler{}
h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond))
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded"))
require.True(t, logSink.ContainsFieldValue("status_code", "200"))
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex"))
require.True(t, logSink.ContainsFieldValue("account_id", "123"))
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok"))
}
func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Status(http.StatusBadGateway)
h := &OpenAIGatewayHandler{}
h.logOpenAIRemoteCompactOutcome(c, time.Now())
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed"))
require.True(t, logSink.ContainsFieldValue("status_code", "502"))
require.True(t, logSink.ContainsFieldValue("path", "/responses/compact"))
}
func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
c.Status(http.StatusOK)
h := &OpenAIGatewayHandler{}
h.logOpenAIRemoteCompactOutcome(c, time.Now())
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
}
func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureHandlerStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
h := &OpenAIGatewayHandler{}
h.Responses(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
require.True(t, logSink.ContainsFieldValue("status_code", "401"))
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
}

View File

@@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
cfg *config.Config
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -61,6 +62,7 @@ func NewOpenAIGatewayHandler(
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
}
}
@@ -70,6 +72,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
compactStartedAt := time.Now()
defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt)
setOpenAIClientTransportHTTP(c)
requestStart := time.Now()
@@ -340,6 +344,86 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
}
func isOpenAIRemoteCompactPath(c *gin.Context) bool {
if c == nil || c.Request == nil || c.Request.URL == nil {
return false
}
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
return strings.HasSuffix(normalizedPath, "/responses/compact")
}
func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) {
if !isOpenAIRemoteCompactPath(c) {
return
}
var (
ctx = context.Background()
path string
status int
)
if c != nil {
if c.Request != nil {
ctx = c.Request.Context()
if c.Request.URL != nil {
path = strings.TrimSpace(c.Request.URL.Path)
}
}
if c.Writer != nil {
status = c.Writer.Status()
}
}
outcome := "failed"
if status >= 200 && status < 300 {
outcome = "succeeded"
}
latencyMs := time.Since(startedAt).Milliseconds()
if latencyMs < 0 {
latencyMs = 0
}
fields := []zap.Field{
zap.String("component", "handler.openai_gateway.responses"),
zap.Bool("remote_compact", true),
zap.String("compact_outcome", outcome),
zap.Int("status_code", status),
zap.Int64("latency_ms", latencyMs),
zap.String("path", path),
zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI),
}
if c != nil {
if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" {
fields = append(fields, zap.String("request_user_agent", userAgent))
}
if v, ok := c.Get(opsModelKey); ok {
if model, ok := v.(string); ok && strings.TrimSpace(model) != "" {
fields = append(fields, zap.String("request_model", strings.TrimSpace(model)))
}
}
if v, ok := c.Get(opsAccountIDKey); ok {
if accountID, ok := v.(int64); ok && accountID > 0 {
fields = append(fields, zap.Int64("account_id", accountID))
}
}
if c.Writer != nil {
if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" {
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
} else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" {
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
}
}
}
log := logger.FromContext(ctx).With(fields...)
if outcome == "succeeded" {
log.Info("codex.remote_compact.succeeded")
return
}
log.Warn("codex.remote_compact.failed")
}
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
return true

View File

@@ -30,6 +30,7 @@ func ProvideAdminHandlers(
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
apiKeyHandler *admin.AdminAPIKeyHandler,
scheduledTestHandler *admin.ScheduledTestHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -53,6 +54,7 @@ func ProvideAdminHandlers(
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
APIKey: apiKeyHandler,
ScheduledTest: scheduledTestHandler,
}
}
@@ -141,6 +143,7 @@ var ProviderSet = wire.NewSet(
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
admin.NewAdminAPIKeyHandler,
admin.NewScheduledTestHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -57,25 +57,28 @@ type DashboardStats struct {
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheCreationTokens int64 `json:"cache_creation_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheCreationTokens int64 `json:"cache_creation_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// GroupStat represents usage statistics for a single group

View File

@@ -437,6 +437,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
switch status {
case "rate_limited":
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
case "temp_unschedulable":
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.And(
entsql.Not(entsql.IsNull(col)),
entsql.GT(col, entsql.Expr("NOW()")),
))
}))
default:
q = q.Where(dbaccount.StatusEQ(status))
}
@@ -640,7 +648,17 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
SetStatus(service.StatusActive).
SetErrorMessage("").
Save(ctx)
return err
if err != nil {
return err
}
// 清除临时不可调度状态,重置 401 升级链
_, _ = r.sql.ExecContext(ctx, `
UPDATE accounts
SET temp_unschedulable_until = NULL,
temp_unschedulable_reason = NULL
WHERE id = $1 AND deleted_at IS NULL
`, id)
return nil
}
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {

View File

@@ -0,0 +1,183 @@
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// --- Plan Repository ---
type scheduledTestPlanRepository struct {
db *sql.DB
}
func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository {
return &scheduledTestPlanRepository{db: db}
}
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans WHERE id = $1
`, id)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans WHERE account_id = $1
ORDER BY created_at DESC
`, accountID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
return scanPlans(rows)
}
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans
WHERE enabled = true AND next_run_at <= $1
ORDER BY next_run_at ASC
`, now)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
return scanPlans(rows)
}
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
UPDATE scheduled_test_plans
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
WHERE id = $1
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error {
_, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id)
return err
}
func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error {
_, err := r.db.ExecContext(ctx, `
UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1
`, id, lastRunAt, nextRunAt)
return err
}
// --- Result Repository ---
type scheduledTestResultRepository struct {
db *sql.DB
}
func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository {
return &scheduledTestResultRepository{db: db}
}
func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) {
row := r.db.QueryRowContext(ctx, `
INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
`, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt)
out := &service.ScheduledTestResult{}
if err := row.Scan(
&out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage,
&out.LatencyMs, &out.StartedAt, &out.FinishedAt, &out.CreatedAt,
); err != nil {
return nil, err
}
return out, nil
}
func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
FROM scheduled_test_results
WHERE plan_id = $1
ORDER BY created_at DESC
LIMIT $2
`, planID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var results []*service.ScheduledTestResult
for rows.Next() {
r := &service.ScheduledTestResult{}
if err := rows.Scan(
&r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage,
&r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt,
); err != nil {
return nil, err
}
results = append(results, r)
}
return results, rows.Err()
}
func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error {
_, err := r.db.ExecContext(ctx, `
DELETE FROM scheduled_test_results
WHERE id IN (
SELECT id FROM (
SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn
FROM scheduled_test_results
WHERE plan_id = $1
) ranked
WHERE rn > $2
)
`, planID, keepCount)
return err
}
// --- scan helpers ---
type scannable interface {
Scan(dest ...any) error
}
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
p := &service.ScheduledTestPlan{}
if err := row.Scan(
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, err
}
return p, nil
}
func scanPlans(rows *sql.Rows) ([]*service.ScheduledTestPlan, error) {
var plans []*service.ScheduledTestPlan
for rows.Next() {
p, err := scanPlan(rows)
if err != nil {
return nil, err
}
plans = append(plans, p)
}
return plans, rows.Err()
}

View File

@@ -1363,7 +1363,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
@@ -1401,6 +1402,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
@@ -1664,7 +1667,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
@@ -1747,7 +1751,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
total_requests as requests,
input_tokens,
output_tokens,
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
cache_creation_tokens,
cache_read_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
@@ -1762,7 +1767,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
total_requests as requests,
input_tokens,
output_tokens,
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
cache_creation_tokens,
cache_read_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
@@ -1806,6 +1812,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
@@ -2622,7 +2630,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.CacheTokens,
&row.CacheCreationTokens,
&row.CacheReadTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
@@ -2646,6 +2655,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.CacheCreationTokens,
&row.CacheReadTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,

View File

@@ -125,7 +125,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
require.NoError(t, err)
@@ -144,7 +144,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
require.NoError(t, err)

View File

@@ -53,7 +53,9 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewScheduledTestPlanRepository, // 定时测试计划仓储
NewScheduledTestResultRepository, // 定时测试结果仓储
NewProxyRepository,
NewRedeemCodeRepository,
NewPromoCodeRepository,

View File

@@ -78,6 +78,9 @@ func RegisterAdminRoutes(
// API Key 管理
registerAdminAPIKeyRoutes(admin, h)
// 定时测试计划
registerScheduledTestRoutes(admin, h)
}
}
@@ -478,6 +481,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
plans := admin.Group("/scheduled-test-plans")
{
plans.POST("", h.Admin.ScheduledTest.Create)
plans.PUT("/:id", h.Admin.ScheduledTest.Update)
plans.DELETE("/:id", h.Admin.ScheduledTest.Delete)
plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults)
}
// Nested under accounts
admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount)
}
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
rules := admin.Group("/error-passthrough-rules")
{

View File

@@ -853,15 +853,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
}
const (
OpenAIWSIngressModeOff = "off"
OpenAIWSIngressModeShared = "shared"
OpenAIWSIngressModeDedicated = "dedicated"
OpenAIWSIngressModeOff = "off"
OpenAIWSIngressModeShared = "shared"
OpenAIWSIngressModeDedicated = "dedicated"
OpenAIWSIngressModeCtxPool = "ctx_pool"
OpenAIWSIngressModePassthrough = "passthrough"
)
func normalizeOpenAIWSIngressMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAIWSIngressModeOff:
return OpenAIWSIngressModeOff
case OpenAIWSIngressModeCtxPool:
return OpenAIWSIngressModeCtxPool
case OpenAIWSIngressModePassthrough:
return OpenAIWSIngressModePassthrough
case OpenAIWSIngressModeShared:
return OpenAIWSIngressModeShared
case OpenAIWSIngressModeDedicated:
@@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
return OpenAIWSIngressModeCtxPool
}
return normalized
}
return OpenAIWSIngressModeShared
return OpenAIWSIngressModeCtxPool
}
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式off/shared/dedicated)。
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式off/ctx_pool/passthrough)。
//
// 优先级:
// 1. 分类型 mode 新字段string
// 2. 分类型 enabled 旧字段bool
// 3. 兼容 enabled 旧字段bool
// 4. defaultMode非法时回退 shared
// 4. defaultMode非法时回退 ctx_pool
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
if a == nil || !a.IsOpenAI() {
@@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return "", false
}
if enabled {
return OpenAIWSIngressModeShared, true
return OpenAIWSIngressModeCtxPool, true
}
return OpenAIWSIngressModeOff, true
}
@@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
return mode
}
// 兼容旧值shared/dedicated 语义都归并到 ctx_pool。
if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
return OpenAIWSIngressModeCtxPool
}
return resolvedDefault
}

View File

@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
}
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
t.Run("default fallback to shared", func(t *testing.T) {
t.Run("default fallback to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
})
t.Run("oauth mode field has highest priority", func(t *testing.T) {
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
"openai_oauth_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": false,
},
}
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("legacy enabled maps to shared", func(t *testing.T) {
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
})
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
shared := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
},
}
dedicated := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
})
t.Run("legacy disabled maps to off", func(t *testing.T) {
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("non openai always off", func(t *testing.T) {

View File

@@ -12,6 +12,7 @@ import (
"io"
"log"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
@@ -33,7 +34,7 @@ import (
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
@@ -238,7 +239,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages?beta=true"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -1560,3 +1561,62 @@ func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) er
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
return fmt.Errorf("%s", errorMsg)
}
// RunTestBackground executes an account test in-memory (no real HTTP client),
// capturing SSE output via httptest.NewRecorder, then parses the result.
func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID int64, modelID string) (*ScheduledTestResult, error) {
startedAt := time.Now()
w := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
finishedAt := time.Now()
body := w.Body.String()
responseText, errMsg := parseTestSSEOutput(body)
status := "success"
if testErr != nil || errMsg != "" {
status = "failed"
if errMsg == "" && testErr != nil {
errMsg = testErr.Error()
}
}
return &ScheduledTestResult{
Status: status,
ResponseText: responseText,
ErrorMessage: errMsg,
LatencyMs: finishedAt.Sub(startedAt).Milliseconds(),
StartedAt: startedAt,
FinishedAt: finishedAt,
}, nil
}
// parseTestSSEOutput extracts response text and error message from captured SSE output.
func parseTestSSEOutput(body string) (responseText, errMsg string) {
var texts []string
for _, line := range strings.Split(body, "\n") {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "data: ") {
continue
}
jsonStr := strings.TrimPrefix(line, "data: ")
var event TestEvent
if err := json.Unmarshal([]byte(jsonStr), &event); err != nil {
continue
}
switch event.Type {
case "content":
if event.Text != "" {
texts = append(texts, event.Text)
}
case "error":
errMsg = event.Error
}
}
responseText = strings.Join(texts, "")
return
}

View File

@@ -88,6 +88,49 @@ func TestCheckErrorPolicy(t *testing.T) {
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled",
account: &Account{
ID: 14,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
account: &Account{
ID: 15,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyNone,
},
{
name: "temp_unschedulable_body_miss_returns_none",
account: &Account{

View File

@@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
require.NotNil(t, result)
require.True(t, result.Stream)
require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体")
require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射")
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
@@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
require.True(t, ok)
bodyBytes, ok := rawBody.([]byte)
require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
require.Equal(t, body, bodyBytes)
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型")
}
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
@@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体")
require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射")
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
@@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
require.Empty(t, rec.Header().Get("Set-Cookie"))
}
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
model string
modelMapping map[string]any // nil = 不配置映射
expectedModel string
endpoint string // "messages" or "count_tokens"
}{
{
name: "Forward: 无映射配置时不改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: nil,
expectedModel: "claude-sonnet-4-20250514",
endpoint: "messages",
},
{
name: "Forward: 空映射配置时不改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{},
expectedModel: "claude-sonnet-4-20250514",
endpoint: "messages",
},
{
name: "Forward: 模型不在映射表中时不改写",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
expectedModel: "claude-sonnet-4-20250514",
endpoint: "messages",
},
{
name: "Forward: 精确匹配映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "messages",
},
{
name: "Forward: 通配符映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "messages",
},
{
name: "CountTokens: 无映射配置时不改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: nil,
expectedModel: "claude-sonnet-4-20250514",
endpoint: "count_tokens",
},
{
name: "CountTokens: 模型不在映射表中时不改写",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
expectedModel: "claude-sonnet-4-20250514",
endpoint: "count_tokens",
},
{
name: "CountTokens: 精确匹配映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "count_tokens",
},
{
name: "CountTokens: 通配符映射应改写模型",
model: "claude-sonnet-4-20250514",
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
expectedModel: "claude-sonnet-4-5-20241022",
endpoint: "count_tokens",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
parsed := &ParsedRequest{
Body: body,
Model: tt.model,
}
credentials := map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
}
if tt.modelMapping != nil {
credentials["model_mapping"] = tt.modelMapping
}
account := &Account{
ID: 300,
Name: "edge-case-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: credentials,
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
if tt.endpoint == "messages" {
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
parsed.Stream = false
upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamJSON)),
},
}
svc := &GatewayService{
cfg: &config.Config{},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
result, err := svc.Forward(context.Background(), c, account, parsed)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
"Forward 上游请求体中的模型应为: %s", tt.expectedModel)
} else {
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
upstreamRespBody := `{"input_tokens":42}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
"CountTokens 上游请求体中的模型应为: %s", tt.expectedModel)
}
})
}
}
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields
// 确保模型映射只替换 model 字段,不影响请求体中的其他字段
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
// 包含复杂字段的请求体system、thinking、messages
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`)
parsed := &ParsedRequest{
Body: body,
Model: "claude-sonnet-4-20250514",
}
upstreamRespBody := `{"input_tokens":42}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 301,
Name: "preserve-fields-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
"model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
},
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
sentBody := upstream.lastBody
require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射")
require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改")
require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改")
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改")
require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改")
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
}
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
// 确保空模型名不会触发映射逻辑
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
parsed := &ParsedRequest{
Body: body,
Model: "", // 空模型
}
upstreamRespBody := `{"input_tokens":10}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 302,
Name: "empty-model-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
"model_mapping": map[string]any{"*": "claude-3-opus-20240229"},
},
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
// 空模型名时body 应原样透传,不应触发映射
require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改")
}
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -3889,7 +3889,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime)
passthroughBody := parsed.Body
passthroughModel := parsed.Model
if passthroughModel != "" {
if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel {
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
passthroughModel = mappedModel
}
}
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
}
body := parsed.Body
@@ -4574,7 +4583,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
targetURL = validatedURL + "/v1/messages?beta=true"
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
@@ -4954,7 +4963,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
targetURL = validatedURL + "/v1/messages?beta=true"
}
}
@@ -6781,7 +6790,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body)
passthroughBody := parsed.Body
if reqModel := parsed.Model; reqModel != "" {
if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel {
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
}
}
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
}
body := parsed.Body
@@ -7072,7 +7088,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
@@ -7119,7 +7135,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
}

View File

@@ -122,6 +122,28 @@ func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "gemini_apikey_temp_unschedulable_401_second_hit_returns_none",
account: &Account{
ID: 105,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyNone,
},
{
name: "gemini_custom_codes_override_temp_unschedulable",
account: &Account{

View File

@@ -19,8 +19,10 @@ import (
// 预编译正则表达式(避免每次调用重新编译)
var (
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
// 匹配 user_id 格式:
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
@@ -239,13 +241,16 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return body, nil
}
// 匹配格式: user_{64位hex}_account__session_{uuid}
// 匹配格式:
// 旧格式: user_{64位hex}_account__session_{uuid}
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
return body, nil
}
sessionTail := matches[1] // 原始session UUID
// matches[1] = account UUID (可能为空), matches[2] = session UUID
sessionTail := matches[2] // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)

View File

@@ -263,13 +263,15 @@ type OpenAIGatewayService struct {
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
openaiSchedulerOnce sync.Once
openaiWSPool *openAIWSConnPool
openaiWSStateStore OpenAIWSStateStore
openaiScheduler OpenAIAccountScheduler
openaiAccountStats *openAIAccountRuntimeStats
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
openaiSchedulerOnce sync.Once
openaiWSPassthroughDialerOnce sync.Once
openaiWSPool *openAIWSConnPool
openaiWSStateStore OpenAIWSStateStore
openaiScheduler OpenAIAccountScheduler
openaiWSPassthroughDialer openAIWSClientDialer
openaiAccountStats *openAIAccountRuntimeStats
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiWSRetryMetrics openAIWSRetryMetrics

View File

@@ -11,6 +11,7 @@ import (
"sync/atomic"
"time"
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws "github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
conn *coderws.Conn
}
var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil)
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
}
}
func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.conn == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
msgType, payload, err := c.conn.Read(ctx)
if err != nil {
return coderws.MessageText, nil, err
}
return msgType, payload, nil
}
func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Write(ctx, msgType, payload)
}
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed

View File

@@ -46,9 +46,10 @@ const (
openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
openAIWSPayloadSizeEstimateMaxItems = 16
openAIWSEventFlushBatchSizeDefault = 4
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
openAIWSPayloadLogSampleDefault = 0.2
openAIWSEventFlushBatchSizeDefault = 4
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
openAIWSPayloadLogSampleDefault = 0.2
openAIWSPassthroughIdleTimeoutDefault = time.Hour
openAIWSStoreDisabledConnModeStrict = "strict"
openAIWSStoreDisabledConnModeAdaptive = "adaptive"
@@ -904,6 +905,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
return s.openaiWSPool
}
func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer {
if s == nil {
return nil
}
s.openaiWSPassthroughDialerOnce.Do(func() {
if s.openaiWSPassthroughDialer == nil {
s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer()
}
})
return s.openaiWSPassthroughDialer
}
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
pool := s.getOpenAIWSConnPool()
if pool == nil {
@@ -967,6 +980,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
return 15 * time.Minute
}
func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration {
if timeout := s.openAIWSReadTimeout(); timeout > 0 {
return timeout
}
return openAIWSPassthroughIdleTimeoutDefault
}
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
@@ -2322,7 +2342,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
ingressMode := OpenAIWSIngressModeShared
ingressMode := OpenAIWSIngressModeCtxPool
if modeRouterV2Enabled {
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
if ingressMode == OpenAIWSIngressModeOff {
@@ -2332,6 +2352,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
nil,
)
}
switch ingressMode {
case OpenAIWSIngressModePassthrough:
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
}
return s.proxyResponsesWebSocketV2Passthrough(
ctx,
c,
clientConn,
account,
token,
firstClientMessage,
hooks,
wsDecision,
)
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
// continue
default:
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"websocket mode only supports ctx_pool/passthrough",
nil,
)
}
}
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)

View File

@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
select {
case serverErr := <-serverErrCh:
@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
upstreamConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPassthroughDialer: captureDialer,
}
account := &Account{
ID: 452,
Name: "openai-ingress-passthrough",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
},
}
serverErrCh := make(chan error, 1)
resultCh := make(chan *OpenAIForwardResult, 1)
hooks := &OpenAIWSIngressHooks{
AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
if turnErr == nil && result != nil {
resultCh <- result
}
},
}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String())
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 passthrough websocket 结束超时")
}
select {
case result := <-resultCh:
require.Equal(t, "resp_passthrough_turn_1", result.RequestID)
require.True(t, result.OpenAIWSMode)
require.Equal(t, 2, result.Usage.InputTokens)
require.Equal(t, 3, result.Usage.OutputTokens)
case <-time.After(2 * time.Second):
t.Fatal("未收到 passthrough turn 结果回调")
}
require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket")
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
return event, nil
}
func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
payload, err := c.ReadMessage(ctx)
if err != nil {
return coderws.MessageText, nil, err
}
return coderws.MessageText, payload, nil
}
func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error {
return c.WriteJSON(ctx, json.RawMessage(payload))
}
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
_ = ctx
return nil

View File

@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
switch mode {
case OpenAIWSIngressModeOff:
return openAIWSHTTPDecision("account_mode_off")
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough:
// continue
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
// 历史值兼容:按 ctx_pool 处理。
mode = OpenAIWSIngressModeCtxPool
default:
return openAIWSHTTPDecision("account_mode_off")
}

View File

@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
},
}
t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("off mode routes to http", func(t *testing.T) {
@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
require.Equal(t, "account_mode_off", decision.Reason)
})
t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
legacyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_shared", decision.Reason)
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("passthrough mode routes to ws v2", func(t *testing.T) {
passthroughAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_passthrough", decision.Reason)
})
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)

View File

@@ -0,0 +1,24 @@
package openai_ws_v2
import (
"context"
)
// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想:
// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。
//
// Reference:
// - Project: caddyserver/caddy (Apache-2.0)
// - Commit: f283062d37c50627d53ca682ebae2ce219b35515
// - Files:
// - modules/caddyhttp/reverseproxy/streaming.go
// - modules/caddyhttp/reverseproxy/reverseproxy.go
func runCaddyStyleRelay(
ctx context.Context,
clientConn FrameConn,
upstreamConn FrameConn,
firstClientMessage []byte,
options RelayOptions,
) (RelayResult, *RelayExit) {
return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
}

View File

@@ -0,0 +1,23 @@
package openai_ws_v2
import "context"
// EntryInput 是 passthrough v2 数据面的入口参数。
type EntryInput struct {
Ctx context.Context
ClientConn FrameConn
UpstreamConn FrameConn
FirstClientMessage []byte
Options RelayOptions
}
// RunEntry 是 openai_ws_v2 包对外的统一入口。
func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
return runCaddyStyleRelay(
input.Ctx,
input.ClientConn,
input.UpstreamConn,
input.FirstClientMessage,
input.Options,
)
}

View File

@@ -0,0 +1,29 @@
package openai_ws_v2
import (
"sync/atomic"
)
// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。
type MetricsSnapshot struct {
SemanticMutationTotal int64 `json:"semantic_mutation_total"`
UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
}
var (
// passthrough 路径默认不会做语义改写,该计数通常应保持为 0保留用于未来防御性校验
passthroughSemanticMutationTotal atomic.Int64
passthroughUsageParseFailureTotal atomic.Int64
)
func recordUsageParseFailure() {
passthroughUsageParseFailureTotal.Add(1)
}
// SnapshotMetrics 返回当前 passthrough 指标快照。
func SnapshotMetrics() MetricsSnapshot {
return MetricsSnapshot{
SemanticMutationTotal: passthroughSemanticMutationTotal.Load(),
UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(),
}
}

View File

@@ -0,0 +1,807 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
coderws "github.com/coder/websocket"
"github.com/tidwall/gjson"
)
type FrameConn interface {
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
Close() error
}
type Usage struct {
InputTokens int
OutputTokens int
CacheCreationInputTokens int
CacheReadInputTokens int
}
type RelayResult struct {
RequestModel string
Usage Usage
RequestID string
TerminalEventType string
FirstTokenMs *int
Duration time.Duration
ClientToUpstreamFrames int64
UpstreamToClientFrames int64
DroppedDownstreamFrames int64
}
type RelayTurnResult struct {
RequestModel string
Usage Usage
RequestID string
TerminalEventType string
Duration time.Duration
FirstTokenMs *int
}
type RelayExit struct {
Stage string
Err error
WroteDownstream bool
}
type RelayOptions struct {
WriteTimeout time.Duration
IdleTimeout time.Duration
UpstreamDrainTimeout time.Duration
FirstMessageType coderws.MessageType
OnUsageParseFailure func(eventType string, usageRaw string)
OnTurnComplete func(turn RelayTurnResult)
OnTrace func(event RelayTraceEvent)
Now func() time.Time
}
type RelayTraceEvent struct {
Stage string
Direction string
MessageType string
PayloadBytes int
Graceful bool
WroteDownstream bool
Error string
}
type relayState struct {
usage Usage
requestModel string
lastResponseID string
terminalEventType string
firstTokenMs *int
turnTimingByID map[string]*relayTurnTiming
}
type relayExitSignal struct {
stage string
err error
graceful bool
wroteDownstream bool
}
type observedUpstreamEvent struct {
terminal bool
eventType string
responseID string
usage Usage
duration time.Duration
firstToken *int
}
type relayTurnTiming struct {
startAt time.Time
firstTokenMs *int
}
func Relay(
ctx context.Context,
clientConn FrameConn,
upstreamConn FrameConn,
firstClientMessage []byte,
options RelayOptions,
) (RelayResult, *RelayExit) {
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
if clientConn == nil || upstreamConn == nil {
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
}
if ctx == nil {
ctx = context.Background()
}
nowFn := options.Now
if nowFn == nil {
nowFn = time.Now
}
writeTimeout := options.WriteTimeout
if writeTimeout <= 0 {
writeTimeout = 2 * time.Minute
}
drainTimeout := options.UpstreamDrainTimeout
if drainTimeout <= 0 {
drainTimeout = 1200 * time.Millisecond
}
firstMessageType := options.FirstMessageType
if firstMessageType != coderws.MessageBinary {
firstMessageType = coderws.MessageText
}
startAt := nowFn()
state := &relayState{requestModel: result.RequestModel}
onTrace := options.OnTrace
relayCtx, relayCancel := context.WithCancel(ctx)
defer relayCancel()
lastActivity := atomic.Int64{}
lastActivity.Store(nowFn().UnixNano())
markActivity := func() {
lastActivity.Store(nowFn().UnixNano())
}
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
defer cancel()
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
}
writeClient := func(msgType coderws.MessageType, payload []byte) error {
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
defer cancel()
return clientConn.WriteFrame(writeCtx, msgType, payload)
}
clientToUpstreamFrames := &atomic.Int64{}
upstreamToClientFrames := &atomic.Int64{}
droppedDownstreamFrames := &atomic.Int64{}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_start",
PayloadBytes: len(firstClientMessage),
MessageType: relayMessageTypeString(firstMessageType),
})
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
result.Duration = nowFn().Sub(startAt)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_failed",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
Error: err.Error(),
})
return result, &RelayExit{Stage: "write_upstream", Err: err}
}
clientToUpstreamFrames.Add(1)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_ok",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
})
markActivity()
exitCh := make(chan relayExitSignal, 3)
dropDownstreamWrites := atomic.Bool{}
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
go runUpstreamToClient(
relayCtx,
upstreamConn,
writeClient,
startAt,
nowFn,
state,
options.OnUsageParseFailure,
options.OnTurnComplete,
&dropDownstreamWrites,
upstreamToClientFrames,
droppedDownstreamFrames,
markActivity,
onTrace,
exitCh,
)
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
firstExit := <-exitCh
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "first_exit",
Direction: relayDirectionFromStage(firstExit.stage),
Graceful: firstExit.graceful,
WroteDownstream: firstExit.wroteDownstream,
Error: relayErrorString(firstExit.err),
})
combinedWroteDownstream := firstExit.wroteDownstream
secondExit := relayExitSignal{graceful: true}
hasSecondExit := false
// 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。
if firstExit.stage == "read_client" && firstExit.graceful {
dropDownstreamWrites.Store(true)
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
} else {
relayCancel()
_ = upstreamConn.Close()
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
}
if hasSecondExit {
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "second_exit",
Direction: relayDirectionFromStage(secondExit.stage),
Graceful: secondExit.graceful,
WroteDownstream: secondExit.wroteDownstream,
Error: relayErrorString(secondExit.err),
})
}
relayCancel()
_ = upstreamConn.Close()
enrichResult(&result, state, nowFn().Sub(startAt))
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
if firstExit.stage == "read_client" && firstExit.graceful {
stage := "client_disconnected"
exitErr := firstExit.err
if hasSecondExit && !secondExit.graceful {
stage = secondExit.stage
exitErr = secondExit.err
}
if exitErr == nil {
exitErr = io.EOF
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(exitErr),
})
return result, &RelayExit{
Stage: stage,
Err: exitErr,
WroteDownstream: combinedWroteDownstream,
}
}
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
_ = clientConn.Close()
return result, nil
}
if !firstExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(firstExit.stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(firstExit.err),
})
return result, &RelayExit{
Stage: firstExit.stage,
Err: firstExit.err,
WroteDownstream: combinedWroteDownstream,
}
}
if hasSecondExit && !secondExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(secondExit.stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(secondExit.err),
})
return result, &RelayExit{
Stage: secondExit.stage,
Err: secondExit.err,
WroteDownstream: combinedWroteDownstream,
}
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
_ = clientConn.Close()
return result, nil
}
func runClientToUpstream(
ctx context.Context,
clientConn FrameConn,
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
markActivity func(),
forwardedFrames *atomic.Int64,
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
for {
msgType, payload, err := clientConn.ReadFrame(ctx)
if err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "read_client_failed",
Direction: "client_to_upstream",
Error: err.Error(),
Graceful: isDisconnectError(err),
})
exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)}
return
}
markActivity()
if err := writeUpstream(msgType, payload); err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_upstream_failed",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
Error: err.Error(),
})
exitCh <- relayExitSignal{stage: "write_upstream", err: err}
return
}
if forwardedFrames != nil {
forwardedFrames.Add(1)
}
markActivity()
}
}
func runUpstreamToClient(
ctx context.Context,
upstreamConn FrameConn,
writeClient func(msgType coderws.MessageType, payload []byte) error,
startAt time.Time,
nowFn func() time.Time,
state *relayState,
onUsageParseFailure func(eventType string, usageRaw string),
onTurnComplete func(turn RelayTurnResult),
dropDownstreamWrites *atomic.Bool,
forwardedFrames *atomic.Int64,
droppedFrames *atomic.Int64,
markActivity func(),
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
wroteDownstream := false
for {
msgType, payload, err := upstreamConn.ReadFrame(ctx)
if err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "read_upstream_failed",
Direction: "upstream_to_client",
Error: err.Error(),
Graceful: isDisconnectError(err),
WroteDownstream: wroteDownstream,
})
exitCh <- relayExitSignal{
stage: "read_upstream",
err: err,
graceful: isDisconnectError(err),
wroteDownstream: wroteDownstream,
}
return
}
markActivity()
observedEvent := observedUpstreamEvent{}
switch msgType {
case coderws.MessageText:
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure)
case coderws.MessageBinary:
// binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。
}
emitTurnComplete(onTurnComplete, state, observedEvent)
if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
if droppedFrames != nil {
droppedFrames.Add(1)
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "drop_downstream_frame",
Direction: "upstream_to_client",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
WroteDownstream: wroteDownstream,
})
if observedEvent.terminal {
exitCh <- relayExitSignal{
stage: "drain_terminal",
graceful: true,
wroteDownstream: wroteDownstream,
}
return
}
markActivity()
continue
}
if err := writeClient(msgType, payload); err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_client_failed",
Direction: "upstream_to_client",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
WroteDownstream: wroteDownstream,
Error: err.Error(),
})
exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
return
}
wroteDownstream = true
if forwardedFrames != nil {
forwardedFrames.Add(1)
}
markActivity()
}
}
func runIdleWatchdog(
ctx context.Context,
nowFn func() time.Time,
idleTimeout time.Duration,
lastActivity *atomic.Int64,
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
if idleTimeout <= 0 {
return
}
checkInterval := minDuration(idleTimeout/4, 5*time.Second)
if checkInterval < time.Second {
checkInterval = time.Second
}
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
last := time.Unix(0, lastActivity.Load())
if nowFn().Sub(last) < idleTimeout {
continue
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "idle_timeout_triggered",
Direction: "watchdog",
Error: context.DeadlineExceeded.Error(),
})
exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
return
}
}
}
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
if onTrace == nil {
return
}
onTrace(event)
}
func relayMessageTypeString(msgType coderws.MessageType) string {
switch msgType {
case coderws.MessageText:
return "text"
case coderws.MessageBinary:
return "binary"
default:
return "unknown(" + strconv.Itoa(int(msgType)) + ")"
}
}
func relayDirectionFromStage(stage string) string {
switch stage {
case "read_client", "write_upstream":
return "client_to_upstream"
case "read_upstream", "write_client", "drain_terminal":
return "upstream_to_client"
case "idle_timeout":
return "watchdog"
default:
return ""
}
}
func relayErrorString(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func observeUpstreamMessage(
state *relayState,
message []byte,
startAt time.Time,
nowFn func() time.Time,
onUsageParseFailure func(eventType string, usageRaw string),
) observedUpstreamEvent {
if state == nil || len(message) == 0 {
return observedUpstreamEvent{}
}
values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
eventType := strings.TrimSpace(values[0].String())
if eventType == "" {
return observedUpstreamEvent{}
}
responseID := strings.TrimSpace(values[1].String())
if responseID == "" {
responseID = strings.TrimSpace(values[2].String())
}
// 仅 terminal 事件兜底读取顶层 id避免把 event_id 当成 response_id 关联到 turn。
if responseID == "" && isTerminalEvent(eventType) {
responseID = strings.TrimSpace(values[3].String())
}
now := nowFn()
if state.firstTokenMs == nil && isTokenEvent(eventType) {
ms := int(now.Sub(startAt).Milliseconds())
if ms >= 0 {
state.firstTokenMs = &ms
}
}
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
observed := observedUpstreamEvent{
eventType: eventType,
responseID: responseID,
usage: parsedUsage,
}
if responseID != "" {
turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
ms := int(now.Sub(turnTiming.startAt).Milliseconds())
if ms >= 0 {
turnTiming.firstTokenMs = &ms
}
}
}
if !isTerminalEvent(eventType) {
return observed
}
observed.terminal = true
state.terminalEventType = eventType
if responseID != "" {
state.lastResponseID = responseID
if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
duration := now.Sub(turnTiming.startAt)
if duration < 0 {
duration = 0
}
observed.duration = duration
observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
}
}
return observed
}
func emitTurnComplete(
onTurnComplete func(turn RelayTurnResult),
state *relayState,
observed observedUpstreamEvent,
) {
if onTurnComplete == nil || !observed.terminal {
return
}
responseID := strings.TrimSpace(observed.responseID)
if responseID == "" {
return
}
requestModel := ""
if state != nil {
requestModel = state.requestModel
}
onTurnComplete(RelayTurnResult{
RequestModel: requestModel,
Usage: observed.usage,
RequestID: responseID,
TerminalEventType: observed.eventType,
Duration: observed.duration,
FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
})
}
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
if state == nil {
return nil
}
if state.turnTimingByID == nil {
state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
}
timing, ok := state.turnTimingByID[responseID]
if !ok || timing == nil || timing.startAt.IsZero() {
timing = &relayTurnTiming{startAt: now}
state.turnTimingByID[responseID] = timing
return timing
}
return timing
}
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
if state == nil || state.turnTimingByID == nil {
return relayTurnTiming{}, false
}
timing, ok := state.turnTimingByID[responseID]
if !ok || timing == nil {
return relayTurnTiming{}, false
}
delete(state.turnTimingByID, responseID)
return *timing, true
}
func openAIWSRelayCloneIntPtr(v *int) *int {
if v == nil {
return nil
}
cloned := *v
return &cloned
}
func parseUsageAndAccumulate(
state *relayState,
message []byte,
eventType string,
onParseFailure func(eventType string, usageRaw string),
) Usage {
if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
return Usage{}
}
usageResult := gjson.GetBytes(message, "response.usage")
if !usageResult.Exists() {
return Usage{}
}
usageRaw := strings.TrimSpace(usageResult.Raw)
if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
recordUsageParseFailure()
if onParseFailure != nil {
onParseFailure(eventType, usageRaw)
}
return Usage{}
}
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
inputTokens, inputOK := parseUsageIntField(inputResult, true)
outputTokens, outputOK := parseUsageIntField(outputResult, true)
cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
if !inputOK || !outputOK || !cachedOK {
recordUsageParseFailure()
if onParseFailure != nil {
onParseFailure(eventType, usageRaw)
}
// 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。
return Usage{}
}
parsedUsage := Usage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheReadInputTokens: cachedTokens,
}
state.usage.InputTokens += parsedUsage.InputTokens
state.usage.OutputTokens += parsedUsage.OutputTokens
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
return parsedUsage
}
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
if !value.Exists() {
return 0, !required
}
if value.Type != gjson.Number {
return 0, false
}
return int(value.Int()), true
}
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
if result == nil {
return
}
result.Duration = duration
if state == nil {
return
}
result.RequestModel = state.requestModel
result.Usage = state.usage
result.RequestID = state.lastResponseID
result.TerminalEventType = state.terminalEventType
result.FirstTokenMs = state.firstTokenMs
}
func isDisconnectError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
return true
}
switch coderws.CloseStatus(err) {
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
return true
}
message := strings.ToLower(strings.TrimSpace(err.Error()))
if message == "" {
return false
}
return strings.Contains(message, "failed to read frame header: eof") ||
strings.Contains(message, "unexpected eof") ||
strings.Contains(message, "use of closed network connection") ||
strings.Contains(message, "connection reset by peer") ||
strings.Contains(message, "broken pipe")
}
func isTerminalEvent(eventType string) bool {
switch eventType {
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false
}
}
func shouldParseUsage(eventType string) bool {
switch eventType {
case "response.completed", "response.done", "response.failed":
return true
default:
return false
}
}
func isTokenEvent(eventType string) bool {
if eventType == "" {
return false
}
switch eventType {
case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
return false
}
if strings.Contains(eventType, ".delta") {
return true
}
if strings.HasPrefix(eventType, "response.output_text") {
return true
}
if strings.HasPrefix(eventType, "response.output") {
return true
}
return eventType == "response.completed" || eventType == "response.done"
}
func minDuration(a, b time.Duration) time.Duration {
if a <= 0 {
return b
}
if b <= 0 {
return a
}
if a < b {
return a
}
return b
}
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
if timeout <= 0 {
timeout = 200 * time.Millisecond
}
select {
case sig := <-exitCh:
return sig, true
case <-time.After(timeout):
return relayExitSignal{}, false
}
}

View File

@@ -0,0 +1,432 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestRunEntry_DelegatesRelay(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true)
result, relayExit := RunEntry(EntryInput{
Ctx: context.Background(),
ClientConn: clientConn,
UpstreamConn: upstreamConn,
FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
})
require.Nil(t, relayExit)
require.Equal(t, "resp_entry", result.RequestID)
}
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
t.Parallel()
t.Run("read client eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.True(t, sig.graceful)
})
t.Run("write upstream failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_upstream", sig.stage)
require.False(t, sig.graceful)
})
t.Run("forwarded counter and trace callback", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
forwarded := &atomic.Int64{}
traces := make([]RelayTraceEvent, 0, 2)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
forwarded,
func(event RelayTraceEvent) {
traces = append(traces, event)
},
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.Equal(t, int64(1), forwarded.Load())
require.NotEmpty(t, traces)
})
}
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
t.Parallel()
t.Run("read upstream eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_upstream", sig.stage)
require.True(t, sig.graceful)
})
t.Run("write client failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_client", sig.stage)
})
t.Run("drop downstream and stop on terminal", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(true)
dropped := &atomic.Int64{}
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
dropped,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "drain_terminal", sig.stage)
require.True(t, sig.graceful)
require.Equal(t, int64(1), dropped.Load())
})
}
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
lastActivity := &atomic.Int64{}
lastActivity.Store(time.Now().UnixNano())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
select {
case <-exitCh:
t.Fatal("unexpected idle timeout signal")
case <-time.After(200 * time.Millisecond):
}
}
func TestHelperFunctionsCoverage(t *testing.T) {
t.Parallel()
require.Equal(t, "text", relayMessageTypeString(coderws.MessageText))
require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary))
require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(")
require.Equal(t, "", relayErrorString(nil))
require.Equal(t, "x", relayErrorString(errors.New("x")))
require.True(t, isDisconnectError(io.EOF))
require.True(t, isDisconnectError(net.ErrClosed))
require.True(t, isDisconnectError(context.Canceled))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway}))
require.True(t, isDisconnectError(errors.New("broken pipe")))
require.False(t, isDisconnectError(errors.New("unrelated")))
require.True(t, isTokenEvent("response.output_text.delta"))
require.True(t, isTokenEvent("response.output_audio.delta"))
require.True(t, isTokenEvent("response.completed"))
require.False(t, isTokenEvent(""))
require.False(t, isTokenEvent("response.created"))
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second))
require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0))
ch := make(chan relayExitSignal, 1)
ch <- relayExitSignal{stage: "ok"}
sig, ok := waitRelayExit(ch, 10*time.Millisecond)
require.True(t, ok)
require.Equal(t, "ok", sig.stage)
ch <- relayExitSignal{stage: "ok2"}
sig, ok = waitRelayExit(ch, 0)
require.True(t, ok)
require.Equal(t, "ok2", sig.stage)
_, ok = waitRelayExit(ch, 10*time.Millisecond)
require.False(t, ok)
n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true)
require.True(t, ok)
require.Equal(t, 3, n)
_, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true)
require.False(t, ok)
n, ok = parseUsageIntField(gjson.Result{}, false)
require.True(t, ok)
require.Equal(t, 0, n)
_, ok = parseUsageIntField(gjson.Result{}, true)
require.False(t, ok)
}
func TestParseUsageAndEnrichCoverage(t *testing.T) {
t.Parallel()
state := &relayState{}
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil)
require.Equal(t, 0, state.usage.InputTokens)
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
require.Equal(t, 2, state.usage.InputTokens)
require.Equal(t, 1, state.usage.OutputTokens)
require.Equal(t, 1, state.usage.CacheReadInputTokens)
result := &RelayResult{}
enrichResult(result, state, 5*time.Millisecond)
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
require.Equal(t, 5*time.Millisecond, result.Duration)
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
require.Equal(t, 2, state.usage.InputTokens)
enrichResult(nil, state, 0)
}
func TestEmitTurnCompleteCoverage(t *testing.T) {
t.Parallel()
// 非 terminal 事件不应触发。
called := 0
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: false,
eventType: "response.output_text.delta",
responseID: "resp_ignored",
usage: Usage{InputTokens: 1},
})
require.Equal(t, 0, called)
// 缺少 response_id 时不应触发。
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
})
require.Equal(t, 0, called)
// terminal 且 response_id 存在应该触发state=nil 时 model 为空串。
var got RelayTurnResult
emitTurnComplete(func(turn RelayTurnResult) {
called++
got = turn
}, nil, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
responseID: "resp_emit",
usage: Usage{InputTokens: 2, OutputTokens: 3},
})
require.Equal(t, 1, called)
require.Equal(t, "resp_emit", got.RequestID)
require.Equal(t, "response.completed", got.TerminalEventType)
require.Equal(t, 2, got.Usage.InputTokens)
require.Equal(t, 3, got.Usage.OutputTokens)
require.Equal(t, "", got.RequestModel)
}
func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) {
t.Parallel()
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure}))
require.True(t, isDisconnectError(errors.New("connection reset by peer")))
require.False(t, isDisconnectError(errors.New(" ")))
}
func TestIsTokenEventCoverageBranches(t *testing.T) {
t.Parallel()
require.False(t, isTokenEvent("response.in_progress"))
require.False(t, isTokenEvent("response.output_item.added"))
require.True(t, isTokenEvent("response.output_audio.delta"))
require.True(t, isTokenEvent("response.output"))
require.True(t, isTokenEvent("response.done"))
}
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
t.Parallel()
now := time.Unix(100, 0)
// nil state
require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now))
_, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
require.False(t, ok)
state := &relayState{}
timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now)
require.NotNil(t, timing)
require.Equal(t, now, timing.startAt)
// 再次获取返回同一条 timing
timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second))
require.NotNil(t, timing2)
require.Equal(t, now, timing2.startAt)
// 删除存在键
deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a")
require.True(t, ok)
require.Equal(t, now, deleted.startAt)
// 删除不存在键
_, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a")
require.False(t, ok)
}
func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) {
t.Parallel()
state := &relayState{requestModel: "gpt-5"}
startAt := time.Unix(0, 0)
now := startAt
nowFn := func() time.Time {
now = now.Add(5 * time.Millisecond)
return now
}
// 非 terminal仅有顶层 id不应把 event id 当成 response_id。
observed := observeUpstreamMessage(
state,
[]byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`),
startAt,
nowFn,
nil,
)
require.False(t, observed.terminal)
require.Equal(t, "", observed.responseID)
// terminal允许兜底用顶层 id用于兼容少数字段变体
observed = observeUpstreamMessage(
state,
[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`),
startAt,
nowFn,
nil,
)
require.True(t, observed.terminal)
require.Equal(t, "resp_fallback", observed.responseID)
}

View File

@@ -0,0 +1,752 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"sync"
"sync/atomic"
"testing"
"time"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
)
type passthroughTestFrame struct {
msgType coderws.MessageType
payload []byte
}
type passthroughTestFrameConn struct {
mu sync.Mutex
writes []passthroughTestFrame
readCh chan passthroughTestFrame
once sync.Once
}
type delayedReadFrameConn struct {
base FrameConn
firstDelay time.Duration
once sync.Once
}
type closeSpyFrameConn struct {
closeCalls atomic.Int32
}
func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn {
c := &passthroughTestFrameConn{
readCh: make(chan passthroughTestFrame, len(frames)+1),
}
for _, frame := range frames {
copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)}
c.readCh <- copied
}
if autoClose {
close(c.readCh)
}
return c
}
func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return coderws.MessageText, nil, ctx.Err()
case frame, ok := <-c.readCh:
if !ok {
return coderws.MessageText, nil, io.EOF
}
return frame.msgType, append([]byte(nil), frame.payload...), nil
}
}
func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
c.mu.Lock()
defer c.mu.Unlock()
c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)})
return nil
}
func (c *passthroughTestFrameConn) Close() error {
c.once.Do(func() {
defer func() { _ = recover() }()
close(c.readCh)
})
return nil
}
func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame {
c.mu.Lock()
defer c.mu.Unlock()
out := make([]passthroughTestFrame, len(c.writes))
copy(out, c.writes)
return out
}
func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.base == nil {
return coderws.MessageText, nil, io.EOF
}
c.once.Do(func() {
if c.firstDelay > 0 {
timer := time.NewTimer(c.firstDelay)
defer timer.Stop()
select {
case <-ctx.Done():
case <-timer.C:
}
}
})
return c.base.ReadFrame(ctx)
}
func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.base == nil {
return io.EOF
}
return c.base.WriteFrame(ctx, msgType, payload)
}
func (c *delayedReadFrameConn) Close() error {
if c == nil || c.base == nil {
return nil
}
return c.base.Close()
}
func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if ctx == nil {
ctx = context.Background()
}
<-ctx.Done()
return coderws.MessageText, nil, ctx.Err()
}
func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error {
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}
func (c *closeSpyFrameConn) Close() error {
if c != nil {
c.closeCalls.Add(1)
}
return nil
}
func (c *closeSpyFrameConn) CloseCalls() int32 {
if c == nil {
return 0
}
return c.closeCalls.Load()
}
func TestRelay_BasicRelayAndUsage(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
require.Equal(t, "gpt-5.3-codex", result.RequestModel)
require.Equal(t, "resp_123", result.RequestID)
require.Equal(t, "response.completed", result.TerminalEventType)
require.Equal(t, 7, result.Usage.InputTokens)
require.Equal(t, 3, result.Usage.OutputTokens)
require.Equal(t, 2, result.Usage.CacheReadInputTokens)
require.NotNil(t, result.FirstTokenMs)
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
require.Equal(t, int64(1), result.UpstreamToClientFrames)
require.Equal(t, int64(0), result.DroppedDownstreamFrames)
upstreamWrites := upstreamConn.Writes()
require.Len(t, upstreamWrites, 1)
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload))
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload))
}
func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
upstreamWrites := upstreamConn.Writes()
require.Len(t, upstreamWrites, 1)
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
require.Equal(t, firstPayload, upstreamWrites[0].payload)
}
func TestRelay_UpstreamDisconnect(t *testing.T) {
t.Parallel()
// 上游立即关闭EOF客户端不发送额外帧
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
// 上游 EOF 属于 disconnect标记为 graceful
require.Nil(t, relayExit, "上游 EOF 应被视为 graceful disconnect")
require.Equal(t, "gpt-4o", result.RequestModel)
}
func TestRelay_ClientDisconnect(t *testing.T) {
t.Parallel()
// 客户端立即关闭EOF上游阻塞读取直到 context 取消
clientConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF
upstreamConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态")
require.Equal(t, "client_disconnected", relayExit.Stage)
require.Equal(t, "gpt-4o", result.RequestModel)
}
func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, true)
upstreamBase := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`),
},
}, true)
upstreamConn := &delayedReadFrameConn{
base: upstreamBase,
firstDelay: 80 * time.Millisecond,
}
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
UpstreamDrainTimeout: 400 * time.Millisecond,
})
require.NotNil(t, relayExit)
require.Equal(t, "client_disconnected", relayExit.Stage)
require.Equal(t, "resp_drain", result.RequestID)
require.Equal(t, "response.completed", result.TerminalEventType)
require.Equal(t, 6, result.Usage.InputTokens)
require.Equal(t, 4, result.Usage.OutputTokens)
require.Equal(t, 1, result.Usage.CacheReadInputTokens)
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
require.Equal(t, int64(0), result.UpstreamToClientFrames)
require.Equal(t, int64(1), result.DroppedDownstreamFrames)
}
func TestRelay_IdleTimeout(t *testing.T) {
t.Parallel()
// 客户端和上游都不发送帧idle timeout 应触发
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 使用快进时间来加速 idle timeout
now := time.Now()
callCount := 0
nowFn := func() time.Time {
callCount++
// 前几次调用返回正常时间(初始化阶段),之后快进
if callCount <= 5 {
return now
}
return now.Add(time.Hour) // 快进到超时
}
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
IdleTimeout: 2 * time.Second,
Now: nowFn,
})
require.NotNil(t, relayExit, "应因 idle timeout 退出")
require.Equal(t, "idle_timeout", relayExit.Stage)
require.Equal(t, "gpt-4o", result.RequestModel)
}
func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) {
t.Parallel()
clientConn := &closeSpyFrameConn{}
upstreamConn := &closeSpyFrameConn{}
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
now := time.Now()
callCount := 0
nowFn := func() time.Time {
callCount++
if callCount <= 5 {
return now
}
return now.Add(time.Hour)
}
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
IdleTimeout: 2 * time.Second,
Now: nowFn,
})
require.NotNil(t, relayExit, "应因 idle timeout 退出")
require.Equal(t, "idle_timeout", relayExit.Stage)
require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code")
require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1))
}
func TestRelay_NilConnections(t *testing.T) {
t.Parallel()
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx := context.Background()
t.Run("nil client conn", func(t *testing.T) {
upstreamConn := newPassthroughTestFrameConn(nil, true)
_, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{})
require.NotNil(t, relayExit)
require.Equal(t, "relay_init", relayExit.Stage)
require.Contains(t, relayExit.Err.Error(), "nil")
})
t.Run("nil upstream conn", func(t *testing.T) {
clientConn := newPassthroughTestFrameConn(nil, true)
_, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{})
require.NotNil(t, relayExit)
require.Equal(t, "relay_init", relayExit.Stage)
require.Contains(t, relayExit.Err.Error(), "nil")
})
}
func TestRelay_MultipleUpstreamMessages(t *testing.T) {
t.Parallel()
// 上游发送多个事件delta + completed验证多帧中继和 usage 聚合
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
require.Equal(t, "resp_multi", result.RequestID)
require.Equal(t, "response.completed", result.TerminalEventType)
require.Equal(t, 10, result.Usage.InputTokens)
require.Equal(t, 5, result.Usage.OutputTokens)
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
require.NotNil(t, result.FirstTokenMs)
// 验证所有 3 个上游帧都转发给了客户端
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 3)
}
func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
turns := make([]RelayTurnResult, 0, 2)
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
OnTurnComplete: func(turn RelayTurnResult) {
turns = append(turns, turn)
},
})
require.Nil(t, relayExit)
require.Len(t, turns, 2)
require.Equal(t, "resp_turn_1", turns[0].RequestID)
require.Equal(t, "response.completed", turns[0].TerminalEventType)
require.Equal(t, 2, turns[0].Usage.InputTokens)
require.Equal(t, 1, turns[0].Usage.OutputTokens)
require.Equal(t, "resp_turn_2", turns[1].RequestID)
require.Equal(t, "response.failed", turns[1].TerminalEventType)
require.Equal(t, 3, turns[1].Usage.InputTokens)
require.Equal(t, 4, turns[1].Usage.OutputTokens)
require.Equal(t, 5, result.Usage.InputTokens)
require.Equal(t, 5, result.Usage.OutputTokens)
}
func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
base := time.Unix(0, 0)
var nowTick atomic.Int64
nowFn := func() time.Time {
step := nowTick.Add(1)
return base.Add(time.Duration(step) * 5 * time.Millisecond)
}
var turn RelayTurnResult
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
Now: nowFn,
OnTurnComplete: func(current RelayTurnResult) {
turn = current
},
})
require.Nil(t, relayExit)
require.Equal(t, "resp_metric", turn.RequestID)
require.Equal(t, "response.completed", turn.TerminalEventType)
require.NotNil(t, turn.FirstTokenMs)
require.GreaterOrEqual(t, *turn.FirstTokenMs, 0)
require.Greater(t, turn.Duration.Milliseconds(), int64(0))
require.NotNil(t, result.FirstTokenMs)
require.Greater(t, result.Duration.Milliseconds(), int64(0))
}
func TestRelay_BinaryFramePassthrough(t *testing.T) {
t.Parallel()
// 验证 binary frame 被透传但不进行 usage 解析
binaryPayload := []byte{0x00, 0x01, 0x02, 0x03}
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageBinary,
payload: binaryPayload,
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
// binary frame 不解析 usage
require.Equal(t, 0, result.Usage.InputTokens)
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
require.Equal(t, binaryPayload, clientWrites[0].payload)
}
func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageBinary,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
require.Equal(t, 0, result.Usage.InputTokens)
require.Equal(t, "", result.RequestID)
require.Equal(t, "", result.TerminalEventType)
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
}
func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: errorEvent,
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
require.Equal(t, errorEvent, clientWrites[0].payload)
}
func TestRelay_PreservesFirstMessageType(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
FirstMessageType: coderws.MessageBinary,
})
require.Nil(t, relayExit)
upstreamWrites := upstreamConn.Writes()
require.Len(t, upstreamWrites, 1)
require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType)
require.Equal(t, firstPayload, upstreamWrites[0].payload)
}
func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) {
baseline := SnapshotMetrics().UsageParseFailureTotal
// 上游发送无效 JSON非 usage 格式),不应影响透传
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
// usage 解析失败,值为 0 但不影响透传
require.Equal(t, 0, result.Usage.InputTokens)
require.Equal(t, "response.completed", result.TerminalEventType)
// 帧仍然被转发
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 1)
require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, baseline+1)
}
func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) {
t.Parallel()
// 上游连接立即关闭,首包写入失败
upstreamConn := newPassthroughTestFrameConn(nil, true)
_ = upstreamConn.Close()
// 覆盖 WriteFrame 使其返回错误
errConn := &errorOnWriteFrameConn{}
clientConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{})
require.NotNil(t, relayExit)
require.Equal(t, "write_upstream", relayExit.Stage)
}
func TestRelay_ContextCanceled(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
// 立即取消 context
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
// context 取消导致写首包失败
require.NotNil(t, relayExit)
}
func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
stages := make([]string, 0, 8)
var stagesMu sync.Mutex
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
OnTrace: func(event RelayTraceEvent) {
stagesMu.Lock()
stages = append(stages, event.Stage)
stagesMu.Unlock()
},
})
require.Nil(t, relayExit)
stagesMu.Lock()
capturedStages := append([]string(nil), stages...)
stagesMu.Unlock()
require.Contains(t, capturedStages, "relay_start")
require.Contains(t, capturedStages, "write_first_message_ok")
require.Contains(t, capturedStages, "first_exit")
require.Contains(t, capturedStages, "relay_complete")
}
func TestRelay_TraceEvents_IdleTimeout(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
now := time.Now()
callCount := 0
nowFn := func() time.Time {
callCount++
if callCount <= 5 {
return now
}
return now.Add(time.Hour)
}
stages := make([]string, 0, 8)
var stagesMu sync.Mutex
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
IdleTimeout: 2 * time.Second,
Now: nowFn,
OnTrace: func(event RelayTraceEvent) {
stagesMu.Lock()
stages = append(stages, event.Stage)
stagesMu.Unlock()
},
})
require.NotNil(t, relayExit)
require.Equal(t, "idle_timeout", relayExit.Stage)
stagesMu.Lock()
capturedStages := append([]string(nil), stages...)
stagesMu.Unlock()
require.Contains(t, capturedStages, "idle_timeout_triggered")
require.Contains(t, capturedStages, "relay_exit")
}
// errorOnWriteFrameConn 是一个写入总是失败的 FrameConn 实现,用于测试首包写入失败。
type errorOnWriteFrameConn struct{}
func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
<-ctx.Done()
return coderws.MessageText, nil, ctx.Err()
}
func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error {
return errors.New("write failed: connection refused")
}
func (c *errorOnWriteFrameConn) Close() error {
return nil
}

View File

@@ -0,0 +1,367 @@
package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.conn == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Read(ctx)
}
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Write(ctx, msgType, payload)
}
func (c *openAIWSClientFrameConn) Close() error {
if c == nil || c.conn == nil {
return nil
}
_ = c.conn.Close(coderws.StatusNormalClosure, "")
_ = c.conn.CloseNow()
return nil
}
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
ctx context.Context,
c *gin.Context,
clientConn *coderws.Conn,
account *Account,
token string,
firstClientMessage []byte,
hooks *OpenAIWSIngressHooks,
wsDecision OpenAIWSProtocolDecision,
) error {
if s == nil {
return errors.New("service is nil")
}
if clientConn == nil {
return errors.New("client websocket is nil")
}
if account == nil {
return errors.New("account is nil")
}
if strings.TrimSpace(token) == "" {
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
account.ID,
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
openaiwsv2RelayMessageTypeName(coderws.MessageText),
len(firstClientMessage),
)
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
}
wsHost := "-"
wsPath := "-"
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
}
logOpenAIWSV2Passthrough(
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
account.ID,
wsHost,
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
isCodexCLI := false
if c != nil {
isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
}
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
isCodexCLI = true
}
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
dialer := s.getOpenAIWSPassthroughDialer()
if dialer == nil {
return errors.New("openai ws passthrough dialer is nil")
}
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
defer cancelDial()
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
if err != nil {
logOpenAIWSV2Passthrough(
"relay_dial_failed account_id=%d status_code=%d err=%s",
account.ID,
statusCode,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
)
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
}
defer func() {
_ = upstreamConn.Close()
}()
logOpenAIWSV2Passthrough(
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
account.ID,
statusCode,
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
)
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
if !ok {
return errors.New("openai ws passthrough upstream connection does not support frame relay")
}
completedTurns := atomic.Int32{}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
WriteTimeout: s.openAIWSWriteTimeout(),
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
FirstMessageType: coderws.MessageText,
OnUsageParseFailure: func(eventType string, usageRaw string) {
logOpenAIWSV2Passthrough(
"usage_parse_failed event_type=%s usage_raw=%s",
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
)
},
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
turnNo := int(completedTurns.Add(1))
turnResult := &OpenAIForwardResult{
RequestID: turn.RequestID,
Usage: OpenAIUsage{
InputTokens: turn.Usage.InputTokens,
OutputTokens: turn.Usage.OutputTokens,
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
Stream: true,
OpenAIWSMode: true,
Duration: turn.Duration,
FirstTokenMs: turn.FirstTokenMs,
}
logOpenAIWSV2Passthrough(
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
account.ID,
turnNo,
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
turnResult.Duration.Milliseconds(),
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
turnResult.Usage.InputTokens,
turnResult.Usage.OutputTokens,
turnResult.Usage.CacheReadInputTokens,
)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnNo, turnResult, nil)
}
},
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
logOpenAIWSV2Passthrough(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
account.ID,
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
event.PayloadBytes,
event.Graceful,
event.WroteDownstream,
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
)
},
},
})
result := &OpenAIForwardResult{
RequestID: relayResult.RequestID,
Usage: OpenAIUsage{
InputTokens: relayResult.Usage.InputTokens,
OutputTokens: relayResult.Usage.OutputTokens,
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
Stream: true,
OpenAIWSMode: true,
Duration: relayResult.Duration,
FirstTokenMs: relayResult.FirstTokenMs,
}
turnCount := int(completedTurns.Load())
if relayExit == nil {
logOpenAIWSV2Passthrough(
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
// 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(1, result, nil)
}
return nil
}
logOpenAIWSV2Passthrough(
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
relayExit.WroteDownstream,
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
relayErr := relayExit.Err
if relayExit.Stage == "idle_timeout" {
relayErr = NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"client websocket idle timeout",
relayErr,
)
}
turnErr := wrapOpenAIWSIngressTurnError(
relayExit.Stage,
relayErr,
relayExit.WroteDownstream,
)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnCount+1, nil, turnErr)
}
return turnErr
}
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
err error,
statusCode int,
handshakeHeaders http.Header,
) error {
if err == nil {
return nil
}
wrappedErr := err
var dialErr *openAIWSDialError
if !errors.As(err, &dialErr) {
wrappedErr = &openAIWSDialError{
StatusCode: statusCode,
ResponseHeaders: cloneHeader(handshakeHeaders),
Err: err,
}
}
if errors.Is(err, context.Canceled) {
return err
}
if errors.Is(err, context.DeadlineExceeded) {
return NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket connect timeout",
wrappedErr,
)
}
if statusCode == http.StatusTooManyRequests {
return NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket is busy, please retry later",
wrappedErr,
)
}
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream websocket authentication failed",
wrappedErr,
)
}
if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream websocket handshake rejected",
wrappedErr,
)
}
return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
}
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
switch msgType {
case coderws.MessageText:
return "text"
case coderws.MessageBinary:
return "binary"
default:
return fmt.Sprintf("unknown(%d)", msgType)
}
}
func relayErrorText(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
if firstTokenMs == nil {
return -1
}
return *firstTokenMs
}
func logOpenAIWSV2Passthrough(format string, args ...any) {
logger.LegacyPrintf(
"service.openai_ws_v2",
"[OpenAI WS v2 passthrough] %s "+format,
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
)
}

View File

@@ -1091,6 +1091,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
if !account.IsTempUnschedulableEnabled() {
return false
}
// 401 首次命中可临时不可调度(给 token 刷新窗口);
// 若历史上已因 401 进入过临时不可调度,则本次应升级为 error返回 false 交由默认错误逻辑处理)。
if statusCode == http.StatusUnauthorized {
reason := account.TempUnschedulableReason
// 缓存可能没有 reason从 DB 回退读取
if reason == "" {
if dbAcc, err := s.accountRepo.GetByID(ctx, account.ID); err == nil && dbAcc != nil {
reason = dbAcc.TempUnschedulableReason
}
}
if wasTempUnschedByStatusCode(reason, statusCode) {
slog.Info("401_escalated_to_error", "account_id", account.ID,
"reason", "previous temp-unschedulable was also 401")
return false
}
}
rules := account.GetTempUnschedulableRules()
if len(rules) == 0 {
return false
@@ -1122,6 +1138,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
return false
}
func wasTempUnschedByStatusCode(reason string, statusCode int) bool {
if statusCode <= 0 {
return false
}
reason = strings.TrimSpace(reason)
if reason == "" {
return false
}
var state TempUnschedState
if err := json.Unmarshal([]byte(reason), &state); err != nil {
return false
}
return state.StatusCode == statusCode
}
func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
if bodyLower == "" {
return ""

View File

@@ -0,0 +1,119 @@
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// dbFallbackRepoStub extends errorPolicyRepoStub with a configurable DB account
// returned by GetByID, simulating cache miss + DB fallback.
type dbFallbackRepoStub struct {
errorPolicyRepoStub
dbAccount *Account // returned by GetByID when non-nil
}
func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
if r.dbAccount != nil && r.dbAccount.ID == id {
return r.dbAccount, nil
}
return nil, nil // not found, no error
}
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "", // cache miss — reason is empty
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
}
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason,
// DB also has no previous 401 record → should NOT escalate (first hit → temp unscheduled).
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 21,
TempUnschedulableReason: "", // DB also empty
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 21,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with no DB record should temp-unschedule")
}
func TestCheckErrorPolicy_401_DBFallback_DBError_FirstHit(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason,
// DB lookup returns nil (not found) → should treat as first hit → temp unscheduled.
repo := &dbFallbackRepoStub{
dbAccount: nil, // GetByID returns nil, nil
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 22,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with DB not found should temp-unschedule")
}

View File

@@ -0,0 +1,51 @@
package service
import (
"context"
"time"
)
// ScheduledTestPlan represents a scheduled test plan domain model.
type ScheduledTestPlan struct {
ID int64 `json:"id"`
AccountID int64 `json:"account_id"`
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression"`
Enabled bool `json:"enabled"`
MaxResults int `json:"max_results"`
LastRunAt *time.Time `json:"last_run_at"`
NextRunAt *time.Time `json:"next_run_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ScheduledTestResult represents a single test execution result.
type ScheduledTestResult struct {
ID int64 `json:"id"`
PlanID int64 `json:"plan_id"`
Status string `json:"status"`
ResponseText string `json:"response_text"`
ErrorMessage string `json:"error_message"`
LatencyMs int64 `json:"latency_ms"`
StartedAt time.Time `json:"started_at"`
FinishedAt time.Time `json:"finished_at"`
CreatedAt time.Time `json:"created_at"`
}
// ScheduledTestPlanRepository defines the data access interface for test plans.
type ScheduledTestPlanRepository interface {
Create(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error)
GetByID(ctx context.Context, id int64) (*ScheduledTestPlan, error)
ListByAccountID(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error)
ListDue(ctx context.Context, now time.Time) ([]*ScheduledTestPlan, error)
Update(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error)
Delete(ctx context.Context, id int64) error
UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error
}
// ScheduledTestResultRepository defines the data access interface for test results.
type ScheduledTestResultRepository interface {
Create(ctx context.Context, result *ScheduledTestResult) (*ScheduledTestResult, error)
ListByPlanID(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error)
PruneOldResults(ctx context.Context, planID int64, keepCount int) error
}

View File

@@ -0,0 +1,139 @@
package service
import (
"context"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/robfig/cron/v3"
)
const scheduledTestDefaultMaxWorkers = 10
// ScheduledTestRunnerService periodically scans due test plans and executes them.
type ScheduledTestRunnerService struct {
planRepo ScheduledTestPlanRepository
scheduledSvc *ScheduledTestService
accountTestSvc *AccountTestService
cfg *config.Config
cron *cron.Cron
startOnce sync.Once
stopOnce sync.Once
}
// NewScheduledTestRunnerService creates a new runner.
func NewScheduledTestRunnerService(
planRepo ScheduledTestPlanRepository,
scheduledSvc *ScheduledTestService,
accountTestSvc *AccountTestService,
cfg *config.Config,
) *ScheduledTestRunnerService {
return &ScheduledTestRunnerService{
planRepo: planRepo,
scheduledSvc: scheduledSvc,
accountTestSvc: accountTestSvc,
cfg: cfg,
}
}
// Start begins the cron ticker (every minute).
func (s *ScheduledTestRunnerService) Start() {
if s == nil {
return
}
s.startOnce.Do(func() {
loc := time.Local
if s.cfg != nil {
if parsed, err := time.LoadLocation(s.cfg.Timezone); err == nil && parsed != nil {
loc = parsed
}
}
c := cron.New(cron.WithParser(scheduledTestCronParser), cron.WithLocation(loc))
_, err := c.AddFunc("* * * * *", func() { s.runScheduled() })
if err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] not started (invalid schedule): %v", err)
return
}
s.cron = c
s.cron.Start()
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] started (tick=every minute)")
})
}
// Stop gracefully shuts down the cron scheduler.
func (s *ScheduledTestRunnerService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.cron != nil {
ctx := s.cron.Stop()
select {
case <-ctx.Done():
case <-time.After(3 * time.Second):
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] cron stop timed out")
}
}
})
}
func (s *ScheduledTestRunnerService) runScheduled() {
// Delay 10s so execution lands at ~:10 of each minute instead of :00.
time.Sleep(10 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
now := time.Now()
plans, err := s.planRepo.ListDue(ctx, now)
if err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] ListDue error: %v", err)
return
}
if len(plans) == 0 {
return
}
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] found %d due plans", len(plans))
sem := make(chan struct{}, scheduledTestDefaultMaxWorkers)
var wg sync.WaitGroup
for _, plan := range plans {
sem <- struct{}{}
wg.Add(1)
go func(p *ScheduledTestPlan) {
defer wg.Done()
defer func() { <-sem }()
s.runOnePlan(ctx, p)
}(plan)
}
wg.Wait()
}
func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *ScheduledTestPlan) {
result, err := s.accountTestSvc.RunTestBackground(ctx, plan.AccountID, plan.ModelID)
if err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d RunTestBackground error: %v", plan.ID, err)
return
}
if err := s.scheduledSvc.SaveResult(ctx, plan.ID, plan.MaxResults, result); err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d SaveResult error: %v", plan.ID, err)
}
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
if err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d computeNextRun error: %v", plan.ID, err)
return
}
if err := s.planRepo.UpdateAfterRun(ctx, plan.ID, time.Now(), nextRun); err != nil {
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d UpdateAfterRun error: %v", plan.ID, err)
}
}

View File

@@ -0,0 +1,94 @@
package service
import (
"context"
"fmt"
"time"
"github.com/robfig/cron/v3"
)
var scheduledTestCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
// ScheduledTestService provides CRUD operations for scheduled test plans and results.
type ScheduledTestService struct {
planRepo ScheduledTestPlanRepository
resultRepo ScheduledTestResultRepository
}
// NewScheduledTestService creates a new ScheduledTestService.
func NewScheduledTestService(
planRepo ScheduledTestPlanRepository,
resultRepo ScheduledTestResultRepository,
) *ScheduledTestService {
return &ScheduledTestService{
planRepo: planRepo,
resultRepo: resultRepo,
}
}
// CreatePlan validates the cron expression, computes next_run_at, and persists the plan.
func (s *ScheduledTestService) CreatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) {
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
if err != nil {
return nil, fmt.Errorf("invalid cron expression: %w", err)
}
plan.NextRunAt = &nextRun
if plan.MaxResults <= 0 {
plan.MaxResults = 50
}
return s.planRepo.Create(ctx, plan)
}
// GetPlan retrieves a plan by ID.
func (s *ScheduledTestService) GetPlan(ctx context.Context, id int64) (*ScheduledTestPlan, error) {
return s.planRepo.GetByID(ctx, id)
}
// ListPlansByAccount returns all plans for a given account.
func (s *ScheduledTestService) ListPlansByAccount(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error) {
return s.planRepo.ListByAccountID(ctx, accountID)
}
// UpdatePlan validates cron and updates the plan.
func (s *ScheduledTestService) UpdatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) {
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
if err != nil {
return nil, fmt.Errorf("invalid cron expression: %w", err)
}
plan.NextRunAt = &nextRun
return s.planRepo.Update(ctx, plan)
}
// DeletePlan removes a plan and its results (via CASCADE).
func (s *ScheduledTestService) DeletePlan(ctx context.Context, id int64) error {
return s.planRepo.Delete(ctx, id)
}
// ListResults returns the most recent results for a plan.
func (s *ScheduledTestService) ListResults(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error) {
if limit <= 0 {
limit = 50
}
return s.resultRepo.ListByPlanID(ctx, planID, limit)
}
// SaveResult inserts a result and prunes old entries beyond maxResults.
func (s *ScheduledTestService) SaveResult(ctx context.Context, planID int64, maxResults int, result *ScheduledTestResult) error {
result.PlanID = planID
if _, err := s.resultRepo.Create(ctx, result); err != nil {
return err
}
return s.resultRepo.PruneOldResults(ctx, planID, maxResults)
}
func computeNextRun(cronExpr string, from time.Time) (time.Time, error) {
sched, err := scheduledTestCronParser.Parse(cronExpr)
if err != nil {
return time.Time{}, err
}
return sched.Next(from), nil
}

View File

@@ -274,6 +274,26 @@ func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Co
return svc
}
// ProvideScheduledTestService creates ScheduledTestService.
func ProvideScheduledTestService(
planRepo ScheduledTestPlanRepository,
resultRepo ScheduledTestResultRepository,
) *ScheduledTestService {
return NewScheduledTestService(planRepo, resultRepo)
}
// ProvideScheduledTestRunnerService creates and starts ScheduledTestRunnerService.
func ProvideScheduledTestRunnerService(
planRepo ScheduledTestPlanRepository,
scheduledSvc *ScheduledTestService,
accountTestSvc *AccountTestService,
cfg *config.Config,
) *ScheduledTestRunnerService {
svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, cfg)
svc.Start()
return svc
}
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
func ProvideOpsScheduledReportService(
opsService *OpsService,
@@ -380,4 +400,6 @@ var ProviderSet = wire.NewSet(
ProvideIdempotencyCoordinator,
ProvideSystemOperationLockService,
ProvideIdempotencyCleanupService,
ProvideScheduledTestService,
ProvideScheduledTestRunnerService,
)

View File

@@ -0,0 +1,30 @@
-- 066_add_scheduled_test_tables.sql
-- Scheduled account test plans and results
CREATE TABLE IF NOT EXISTS scheduled_test_plans (
id BIGSERIAL PRIMARY KEY,
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
model_id VARCHAR(100) NOT NULL DEFAULT '',
cron_expression VARCHAR(100) NOT NULL DEFAULT '*/30 * * * *',
enabled BOOLEAN NOT NULL DEFAULT true,
max_results INT NOT NULL DEFAULT 50,
last_run_at TIMESTAMPTZ,
next_run_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_stp_account_id ON scheduled_test_plans(account_id);
CREATE INDEX IF NOT EXISTS idx_stp_enabled_next_run ON scheduled_test_plans(enabled, next_run_at) WHERE enabled = true;
CREATE TABLE IF NOT EXISTS scheduled_test_results (
id BIGSERIAL PRIMARY KEY,
plan_id BIGINT NOT NULL REFERENCES scheduled_test_plans(id) ON DELETE CASCADE,
status VARCHAR(20) NOT NULL DEFAULT 'success',
response_text TEXT NOT NULL DEFAULT '',
error_message TEXT NOT NULL DEFAULT '',
latency_ms BIGINT NOT NULL DEFAULT 0,
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
finished_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_str_plan_created ON scheduled_test_results(plan_id, created_at DESC);

View File

@@ -209,8 +209,9 @@ gateway:
openai_ws:
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
mode_router_v2_enabled: false
# ingress 默认模式off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
ingress_mode_default: shared
# ingress 默认模式off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效)
# 兼容旧值shared/dedicated 会按 ctx_pool 处理。
ingress_mode_default: ctx_pool
# 全局总开关,默认 true关闭时所有请求保持原有 HTTP/SSE 路由
enabled: true
# 按账号类型细分开关

View File

@@ -22,6 +22,7 @@ import opsAPI from './ops'
import errorPassthroughAPI from './errorPassthrough'
import dataManagementAPI from './dataManagement'
import apiKeysAPI from './apiKeys'
import scheduledTestsAPI from './scheduledTests'
/**
* Unified admin API object for convenient access
@@ -45,7 +46,8 @@ export const adminAPI = {
ops: opsAPI,
errorPassthrough: errorPassthroughAPI,
dataManagement: dataManagementAPI,
apiKeys: apiKeysAPI
apiKeys: apiKeysAPI,
scheduledTests: scheduledTestsAPI
}
export {
@@ -67,7 +69,8 @@ export {
opsAPI,
errorPassthroughAPI,
dataManagementAPI,
apiKeysAPI
apiKeysAPI,
scheduledTestsAPI
}
export default adminAPI

View File

@@ -0,0 +1,85 @@
/**
* Admin Scheduled Tests API endpoints
* Handles scheduled test plan management for account connectivity monitoring
*/
import { apiClient } from '../client'
import type {
ScheduledTestPlan,
ScheduledTestResult,
CreateScheduledTestPlanRequest,
UpdateScheduledTestPlanRequest
} from '@/types'
/**
* List all scheduled test plans for an account
* @param accountId - Account ID
* @returns List of scheduled test plans
*/
export async function listByAccount(accountId: number): Promise<ScheduledTestPlan[]> {
const { data } = await apiClient.get<ScheduledTestPlan[]>(
`/admin/accounts/${accountId}/scheduled-test-plans`
)
return data ?? []
}
/**
* Create a new scheduled test plan
* @param req - Plan creation request
* @returns Created plan
*/
export async function create(req: CreateScheduledTestPlanRequest): Promise<ScheduledTestPlan> {
const { data } = await apiClient.post<ScheduledTestPlan>(
'/admin/scheduled-test-plans',
req
)
return data
}
/**
* Update an existing scheduled test plan
* @param id - Plan ID
* @param req - Fields to update
* @returns Updated plan
*/
export async function update(id: number, req: UpdateScheduledTestPlanRequest): Promise<ScheduledTestPlan> {
const { data } = await apiClient.put<ScheduledTestPlan>(
`/admin/scheduled-test-plans/${id}`,
req
)
return data
}
/**
* Delete a scheduled test plan
* @param id - Plan ID
*/
export async function deletePlan(id: number): Promise<void> {
await apiClient.delete(`/admin/scheduled-test-plans/${id}`)
}
/**
* List test results for a plan
* @param planId - Plan ID
* @param limit - Optional max number of results to return
* @returns List of test results
*/
export async function listResults(planId: number, limit?: number): Promise<ScheduledTestResult[]> {
const { data } = await apiClient.get<ScheduledTestResult[]>(
`/admin/scheduled-test-plans/${planId}/results`,
{
params: limit ? { limit } : undefined
}
)
return data ?? []
}
export const scheduledTestsAPI = {
listByAccount,
create,
update,
delete: deletePlan,
listResults
}
export default scheduledTestsAPI

View File

@@ -1807,7 +1807,7 @@
</div>
</div>
<!-- OpenAI WS Mode 三态off/shared/dedicated -->
<!-- OpenAI WS Mode 三态off/ctx_pool/passthrough -->
<div
v-if="form.platform === 'openai' && (accountCategory === 'oauth-based' || accountCategory === 'apikey')"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
@@ -1819,7 +1819,7 @@
{{ t('admin.accounts.openai.wsModeDesc') }}
</p>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.wsModeConcurrencyHint') }}
{{ t(openAIWSModeConcurrencyHintKey) }}
</p>
</div>
<div class="w-52">
@@ -2341,10 +2341,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import {
OPENAI_WS_MODE_DEDICATED,
// OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_SHARED,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
resolveOpenAIWSModeConcurrencyHintKey,
type OpenAIWSMode
} from '@/utils/openaiWsMode'
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
@@ -2541,8 +2542,9 @@ const geminiSelectedTier = computed(() => {
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
{ value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') },
{ value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') }
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
// { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
const openaiResponsesWebSocketV2Mode = computed({
@@ -2561,6 +2563,10 @@ const openaiResponsesWebSocketV2Mode = computed({
}
})
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value)
)
const isOpenAIModelRestrictionDisabled = computed(() =>
form.platform === 'openai' && openaiPassthroughEnabled.value
)
@@ -3180,10 +3186,13 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
}
const extra: Record<string, unknown> = { ...(base || {}) }
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
if (accountCategory.value === 'oauth-based') {
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
} else if (accountCategory.value === 'apikey') {
extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
}
// 清理兼容旧键,统一改用分类型开关。
delete extra.responses_websockets_v2_enabled
delete extra.openai_ws_enabled

View File

@@ -708,7 +708,7 @@
</div>
</div>
<!-- OpenAI WS Mode 三态off/shared/dedicated -->
<!-- OpenAI WS Mode 三态off/ctx_pool/passthrough -->
<div
v-if="account?.platform === 'openai' && (account?.type === 'oauth' || account?.type === 'apikey')"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
@@ -720,7 +720,7 @@
{{ t('admin.accounts.openai.wsModeDesc') }}
</p>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.wsModeConcurrencyHint') }}
{{ t(openAIWSModeConcurrencyHintKey) }}
</p>
</div>
<div class="w-52">
@@ -1273,10 +1273,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import {
OPENAI_WS_MODE_DEDICATED,
// OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_SHARED,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
resolveOpenAIWSModeConcurrencyHintKey,
type OpenAIWSMode,
resolveOpenAIWSModeFromExtra
} from '@/utils/openaiWsMode'
@@ -1387,8 +1388,9 @@ const codexCLIOnlyEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
{ value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') },
{ value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') }
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
// { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
const openaiResponsesWebSocketV2Mode = computed({
get: () => {
@@ -1405,6 +1407,9 @@ const openaiResponsesWebSocketV2Mode = computed({
openaiOAuthResponsesWebSocketV2Mode.value = mode
}
})
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value)
)
const isOpenAIModelRestrictionDisabled = computed(() =>
props.account?.platform === 'openai' && openaiPassthroughEnabled.value
)
@@ -2248,10 +2253,13 @@ const handleSubmit = async () => {
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
const newExtra: Record<string, unknown> = { ...currentExtra }
const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true
newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
if (props.account.type === 'oauth') {
newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
} else if (props.account.type === 'apikey') {
newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
}
delete newExtra.responses_websockets_v2_enabled
delete newExtra.openai_ws_enabled
if (openaiPassthroughEnabled.value) {

View File

@@ -18,6 +18,10 @@
<Icon name="chart" size="sm" class="text-indigo-500" />
{{ t('admin.accounts.viewStats') }}
</button>
<button @click="$emit('schedule', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm hover:bg-gray-100 dark:hover:bg-dark-700">
<Icon name="clock" size="sm" class="text-orange-500" />
{{ t('admin.scheduledTests.schedule') }}
</button>
<template v-if="account.type === 'oauth' || account.type === 'setup-token'">
<button @click="$emit('reauth', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm text-blue-600 hover:bg-gray-100 dark:hover:bg-dark-700">
<Icon name="link" size="sm" />
@@ -51,7 +55,7 @@ import { Icon } from '@/components/icons'
import type { Account } from '@/types'
const props = defineProps<{ show: boolean; account: Account | null; position: { top: number; left: number } | null }>()
const emit = defineEmits(['close', 'test', 'stats', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit'])
const emit = defineEmits(['close', 'test', 'stats', 'schedule', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit'])
const { t } = useI18n()
const isRateLimited = computed(() => {
if (props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date()) {

View File

@@ -25,6 +25,6 @@ const updateStatus = (value: string | number | boolean | null) => { emit('update
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }])
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }])
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
const gOpts = computed(() => [{ value: '', label: t('admin.accounts.allGroups') }, ...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))])
</script>

View File

@@ -0,0 +1,587 @@
<template>
<BaseDialog
:show="show"
:title="t('admin.scheduledTests.title')"
width="wide"
@close="emit('close')"
>
<div class="space-y-4">
<!-- Add Plan Button -->
<div class="flex items-center justify-between">
<p class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.scheduledTests.title') }}
</p>
<button
@click="showAddForm = !showAddForm"
class="btn btn-primary flex items-center gap-1.5 text-sm"
>
<Icon name="plus" size="sm" :stroke-width="2" />
{{ t('admin.scheduledTests.addPlan') }}
</button>
</div>
<!-- Add Plan Form -->
<div
v-if="showAddForm"
class="rounded-xl border border-primary-200 bg-primary-50/50 p-4 dark:border-primary-800 dark:bg-primary-900/20"
>
<div class="mb-3 text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.scheduledTests.addPlan') }}
</div>
<div class="grid grid-cols-1 gap-3 sm:grid-cols-2">
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.scheduledTests.model') }}
</label>
<Select
v-model="newPlan.model_id"
:options="modelOptions"
:placeholder="t('admin.scheduledTests.model')"
:searchable="modelOptions.length > 5"
/>
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.scheduledTests.cronExpression') }}
</label>
<Input
v-model="newPlan.cron_expression"
:placeholder="'*/30 * * * *'"
:hint="t('admin.scheduledTests.cronHelp')"
/>
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.scheduledTests.maxResults') }}
</label>
<Input
v-model="newPlan.max_results"
type="number"
placeholder="100"
/>
</div>
<div class="flex items-end">
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
<Toggle v-model="newPlan.enabled" />
{{ t('admin.scheduledTests.enabled') }}
</label>
</div>
</div>
<div class="mt-3 flex justify-end gap-2">
<button
@click="showAddForm = false; resetNewPlan()"
class="rounded-lg bg-gray-100 px-3 py-1.5 text-sm font-medium text-gray-700 transition-colors hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-300 dark:hover:bg-dark-500"
>
{{ t('common.cancel') }}
</button>
<button
@click="handleCreate"
:disabled="!newPlan.model_id || !newPlan.cron_expression || creating"
class="flex items-center gap-1.5 rounded-lg bg-primary-500 px-3 py-1.5 text-sm font-medium text-white transition-colors hover:bg-primary-600 disabled:cursor-not-allowed disabled:opacity-50"
>
<Icon v-if="creating" name="refresh" size="sm" class="animate-spin" :stroke-width="2" />
{{ t('common.save') }}
</button>
</div>
</div>
<!-- Loading State -->
<div v-if="loading" class="flex items-center justify-center py-8">
<Icon name="refresh" size="md" class="animate-spin text-gray-400" :stroke-width="2" />
<span class="ml-2 text-sm text-gray-500">{{ t('common.loading') }}...</span>
</div>
<!-- Empty State -->
<div
v-else-if="plans.length === 0"
class="rounded-xl border border-dashed border-gray-300 py-10 text-center dark:border-dark-600"
>
<Icon name="calendar" size="lg" class="mx-auto mb-2 text-gray-400" :stroke-width="1.5" />
<p class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.scheduledTests.noPlans') }}
</p>
</div>
<!-- Plans List -->
<div v-else class="space-y-3">
<div
v-for="plan in plans"
:key="plan.id"
class="rounded-xl border border-gray-200 bg-white transition-all dark:border-dark-600 dark:bg-dark-800"
>
<!-- Plan Header -->
<div
class="flex cursor-pointer items-center justify-between px-4 py-3"
@click="toggleExpand(plan.id)"
>
<div class="flex flex-1 items-center gap-4">
<!-- Model -->
<div class="min-w-0">
<div class="text-sm font-medium text-gray-900 dark:text-gray-100">
{{ plan.model_id }}
</div>
<div class="mt-0.5 font-mono text-xs text-gray-500 dark:text-gray-400">
{{ plan.cron_expression }}
</div>
</div>
<!-- Enabled Toggle -->
<div class="flex items-center gap-1.5" @click.stop>
<Toggle
:model-value="plan.enabled"
@update:model-value="(val: boolean) => handleToggleEnabled(plan, val)"
/>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ plan.enabled ? t('admin.scheduledTests.enabled') : '' }}
</span>
</div>
</div>
<div class="flex items-center gap-3">
<!-- Last Run -->
<div v-if="plan.last_run_at" class="hidden text-right text-xs text-gray-500 dark:text-gray-400 sm:block">
<div>{{ t('admin.scheduledTests.lastRun') }}</div>
<div>{{ formatDateTime(plan.last_run_at) }}</div>
</div>
<!-- Next Run -->
<div v-if="plan.next_run_at" class="hidden text-right text-xs text-gray-500 dark:text-gray-400 sm:block">
<div>{{ t('admin.scheduledTests.nextRun') }}</div>
<div>{{ formatDateTime(plan.next_run_at) }}</div>
</div>
<!-- Actions -->
<div class="flex items-center gap-1" @click.stop>
<button
@click="startEdit(plan)"
class="rounded-lg p-1.5 text-gray-400 transition-colors hover:bg-blue-50 hover:text-blue-500 dark:hover:bg-blue-900/20"
:title="t('admin.scheduledTests.editPlan')"
>
<Icon name="edit" size="sm" :stroke-width="2" />
</button>
<button
@click="confirmDeletePlan(plan)"
class="rounded-lg p-1.5 text-gray-400 transition-colors hover:bg-red-50 hover:text-red-500 dark:hover:bg-red-900/20"
:title="t('admin.scheduledTests.deletePlan')"
>
<Icon name="trash" size="sm" :stroke-width="2" />
</button>
</div>
<!-- Expand indicator -->
<Icon
name="chevronDown"
size="sm"
:class="[
'text-gray-400 transition-transform duration-200',
expandedPlanId === plan.id ? 'rotate-180' : ''
]"
/>
</div>
</div>
<!-- Edit Form -->
<div
v-if="editingPlanId === plan.id"
class="border-t border-blue-100 bg-blue-50/50 px-4 py-3 dark:border-blue-900 dark:bg-blue-900/10"
@click.stop
>
<div class="mb-2 text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.scheduledTests.editPlan') }}
</div>
<div class="grid grid-cols-1 gap-3 sm:grid-cols-2">
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.scheduledTests.model') }}
</label>
<Select
v-model="editForm.model_id"
:options="modelOptions"
:placeholder="t('admin.scheduledTests.model')"
:searchable="modelOptions.length > 5"
/>
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.scheduledTests.cronExpression') }}
</label>
<Input
v-model="editForm.cron_expression"
:placeholder="'*/30 * * * *'"
:hint="t('admin.scheduledTests.cronHelp')"
/>
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.scheduledTests.maxResults') }}
</label>
<Input
v-model="editForm.max_results"
type="number"
placeholder="100"
/>
</div>
<div class="flex items-end">
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
<Toggle v-model="editForm.enabled" />
{{ t('admin.scheduledTests.enabled') }}
</label>
</div>
</div>
<div class="mt-3 flex justify-end gap-2">
<button
@click="cancelEdit"
class="rounded-lg bg-gray-100 px-3 py-1.5 text-sm font-medium text-gray-700 transition-colors hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-300 dark:hover:bg-dark-500"
>
{{ t('common.cancel') }}
</button>
<button
@click="handleEdit"
:disabled="!editForm.model_id || !editForm.cron_expression || updating"
class="flex items-center gap-1.5 rounded-lg bg-primary-500 px-3 py-1.5 text-sm font-medium text-white transition-colors hover:bg-primary-600 disabled:cursor-not-allowed disabled:opacity-50"
>
<Icon v-if="updating" name="refresh" size="sm" class="animate-spin" :stroke-width="2" />
{{ t('common.save') }}
</button>
</div>
</div>
<!-- Expanded Results Section -->
<div
v-if="expandedPlanId === plan.id"
class="border-t border-gray-100 px-4 py-3 dark:border-dark-700"
>
<div class="mb-2 text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.scheduledTests.results') }}
</div>
<!-- Results Loading -->
<div v-if="loadingResults" class="flex items-center justify-center py-4">
<Icon name="refresh" size="sm" class="animate-spin text-gray-400" :stroke-width="2" />
<span class="ml-2 text-xs text-gray-500">{{ t('common.loading') }}...</span>
</div>
<!-- No Results -->
<div
v-else-if="results.length === 0"
class="py-4 text-center text-xs text-gray-500 dark:text-gray-400"
>
{{ t('admin.scheduledTests.noResults') }}
</div>
<!-- Results List -->
<div v-else class="max-h-64 space-y-2 overflow-y-auto">
<div
v-for="result in results"
:key="result.id"
class="rounded-lg border border-gray-100 bg-gray-50 p-3 dark:border-dark-700 dark:bg-dark-900"
>
<div class="flex items-center justify-between">
<div class="flex items-center gap-2">
<!-- Status Badge -->
<span
:class="[
'inline-flex items-center rounded-full px-2 py-0.5 text-xs font-medium',
result.status === 'success'
? 'bg-green-100 text-green-700 dark:bg-green-500/20 dark:text-green-400'
: result.status === 'running'
? 'bg-blue-100 text-blue-700 dark:bg-blue-500/20 dark:text-blue-400'
: 'bg-red-100 text-red-700 dark:bg-red-500/20 dark:text-red-400'
]"
>
{{
result.status === 'success'
? t('admin.scheduledTests.success')
: result.status === 'running'
? t('admin.scheduledTests.running')
: t('admin.scheduledTests.failed')
}}
</span>
<!-- Latency -->
<span v-if="result.latency_ms > 0" class="text-xs text-gray-500 dark:text-gray-400">
{{ result.latency_ms }}ms
</span>
</div>
<!-- Started At -->
<span class="text-xs text-gray-400">
{{ formatDateTime(result.started_at) }}
</span>
</div>
<!-- Response / Error (collapsible) -->
<div v-if="result.error_message" class="mt-2">
<div
class="cursor-pointer text-xs font-medium text-red-600 dark:text-red-400"
@click="toggleResultDetail(result.id)"
>
{{ t('admin.scheduledTests.errorMessage') }}
<Icon
name="chevronDown"
size="sm"
:class="[
'inline transition-transform duration-200',
expandedResultIds.has(result.id) ? 'rotate-180' : ''
]"
/>
</div>
<pre
v-if="expandedResultIds.has(result.id)"
class="mt-1 max-h-32 overflow-auto whitespace-pre-wrap rounded bg-red-50 p-2 text-xs text-red-700 dark:bg-red-900/20 dark:text-red-300"
>{{ result.error_message }}</pre>
</div>
<div v-else-if="result.response_text" class="mt-2">
<div
class="cursor-pointer text-xs font-medium text-gray-600 dark:text-gray-400"
@click="toggleResultDetail(result.id)"
>
{{ t('admin.scheduledTests.responseText') }}
<Icon
name="chevronDown"
size="sm"
:class="[
'inline transition-transform duration-200',
expandedResultIds.has(result.id) ? 'rotate-180' : ''
]"
/>
</div>
<pre
v-if="expandedResultIds.has(result.id)"
class="mt-1 max-h-32 overflow-auto whitespace-pre-wrap rounded bg-gray-100 p-2 text-xs text-gray-700 dark:bg-dark-800 dark:text-gray-300"
>{{ result.response_text }}</pre>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- Delete Confirmation -->
<ConfirmDialog
:show="showDeleteConfirm"
:title="t('admin.scheduledTests.deletePlan')"
:message="t('admin.scheduledTests.confirmDelete')"
:confirm-text="t('common.delete')"
:cancel-text="t('common.cancel')"
:danger="true"
@confirm="handleDelete"
@cancel="showDeleteConfirm = false"
/>
</BaseDialog>
</template>
<script setup lang="ts">
import { ref, reactive, watch } from 'vue'
import { useI18n } from 'vue-i18n'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select, { type SelectOption } from '@/components/common/Select.vue'
import Input from '@/components/common/Input.vue'
import Toggle from '@/components/common/Toggle.vue'
import { Icon } from '@/components/icons'
import { adminAPI } from '@/api/admin'
import { useAppStore } from '@/stores/app'
import { formatDateTime } from '@/utils/format'
import type { ScheduledTestPlan, ScheduledTestResult } from '@/types'
const { t } = useI18n()
const appStore = useAppStore()
const props = defineProps<{
show: boolean
accountId: number | null
modelOptions: SelectOption[]
}>()
const emit = defineEmits<{
(e: 'close'): void
}>()
// State
const loading = ref(false)
const creating = ref(false)
const loadingResults = ref(false)
const plans = ref<ScheduledTestPlan[]>([])
const results = ref<ScheduledTestResult[]>([])
const expandedPlanId = ref<number | null>(null)
const expandedResultIds = reactive(new Set<number>())
const showAddForm = ref(false)
const showDeleteConfirm = ref(false)
const deletingPlan = ref<ScheduledTestPlan | null>(null)
const editingPlanId = ref<number | null>(null)
const updating = ref(false)
const editForm = reactive({
model_id: '' as string,
cron_expression: '' as string,
max_results: '100' as string,
enabled: true
})
const newPlan = reactive({
model_id: '' as string,
cron_expression: '' as string,
max_results: '100' as string,
enabled: true
})
const resetNewPlan = () => {
newPlan.model_id = ''
newPlan.cron_expression = ''
newPlan.max_results = '100'
newPlan.enabled = true
}
// Load plans when dialog opens
watch(
() => props.show,
async (visible) => {
if (visible && props.accountId) {
await loadPlans()
} else {
plans.value = []
results.value = []
expandedPlanId.value = null
expandedResultIds.clear()
showAddForm.value = false
showDeleteConfirm.value = false
}
}
)
const loadPlans = async () => {
if (!props.accountId) return
loading.value = true
try {
plans.value = await adminAPI.scheduledTests.listByAccount(props.accountId)
} catch (error: any) {
appStore.showError(error?.message || 'Failed to load plans')
} finally {
loading.value = false
}
}
const handleCreate = async () => {
if (!props.accountId || !newPlan.model_id || !newPlan.cron_expression) return
creating.value = true
try {
const maxResults = Number(newPlan.max_results) || 100
await adminAPI.scheduledTests.create({
account_id: props.accountId,
model_id: newPlan.model_id,
cron_expression: newPlan.cron_expression,
enabled: newPlan.enabled,
max_results: maxResults
})
appStore.showSuccess(t('admin.scheduledTests.createSuccess'))
showAddForm.value = false
resetNewPlan()
await loadPlans()
} catch (error: any) {
appStore.showError(error?.message || 'Failed to create plan')
} finally {
creating.value = false
}
}
const handleToggleEnabled = async (plan: ScheduledTestPlan, enabled: boolean) => {
try {
const updated = await adminAPI.scheduledTests.update(plan.id, { enabled })
const index = plans.value.findIndex((p) => p.id === plan.id)
if (index !== -1) {
plans.value[index] = updated
}
appStore.showSuccess(t('admin.scheduledTests.updateSuccess'))
} catch (error: any) {
appStore.showError(error?.message || 'Failed to update plan')
}
}
const startEdit = (plan: ScheduledTestPlan) => {
editingPlanId.value = plan.id
editForm.model_id = plan.model_id
editForm.cron_expression = plan.cron_expression
editForm.max_results = String(plan.max_results)
editForm.enabled = plan.enabled
}
const cancelEdit = () => {
editingPlanId.value = null
}
const handleEdit = async () => {
if (!editingPlanId.value || !editForm.model_id || !editForm.cron_expression) return
updating.value = true
try {
const updated = await adminAPI.scheduledTests.update(editingPlanId.value, {
model_id: editForm.model_id,
cron_expression: editForm.cron_expression,
max_results: Number(editForm.max_results) || 100,
enabled: editForm.enabled
})
const index = plans.value.findIndex((p) => p.id === editingPlanId.value)
if (index !== -1) {
plans.value[index] = updated
}
appStore.showSuccess(t('admin.scheduledTests.updateSuccess'))
editingPlanId.value = null
} catch (error: any) {
appStore.showError(error?.message || 'Failed to update plan')
} finally {
updating.value = false
}
}
const confirmDeletePlan = (plan: ScheduledTestPlan) => {
deletingPlan.value = plan
showDeleteConfirm.value = true
}
const handleDelete = async () => {
if (!deletingPlan.value) return
try {
await adminAPI.scheduledTests.delete(deletingPlan.value.id)
appStore.showSuccess(t('admin.scheduledTests.deleteSuccess'))
plans.value = plans.value.filter((p) => p.id !== deletingPlan.value!.id)
if (expandedPlanId.value === deletingPlan.value.id) {
expandedPlanId.value = null
results.value = []
}
} catch (error: any) {
appStore.showError(error?.message || 'Failed to delete plan')
} finally {
showDeleteConfirm.value = false
deletingPlan.value = null
}
}
const toggleExpand = async (planId: number) => {
if (expandedPlanId.value === planId) {
expandedPlanId.value = null
results.value = []
expandedResultIds.clear()
return
}
expandedPlanId.value = planId
expandedResultIds.clear()
loadingResults.value = true
try {
results.value = await adminAPI.scheduledTests.listResults(planId, 20)
} catch (error: any) {
appStore.showError(error?.message || 'Failed to load results')
results.value = []
} finally {
loadingResults.value = false
}
}
const toggleResultDetail = (resultId: number) => {
if (expandedResultIds.has(resultId)) {
expandedResultIds.delete(resultId)
} else {
expandedResultIds.add(resultId)
}
}
</script>

View File

@@ -63,7 +63,8 @@ const chartColors = computed(() => ({
grid: isDarkMode.value ? '#374151' : '#e5e7eb',
input: '#3b82f6',
output: '#10b981',
cache: '#f59e0b'
cacheCreation: '#f59e0b',
cacheRead: '#06b6d4'
}))
const chartData = computed(() => {
@@ -89,10 +90,18 @@ const chartData = computed(() => {
tension: 0.3
},
{
label: 'Cache',
data: props.trendData.map((d) => d.cache_tokens),
borderColor: chartColors.value.cache,
backgroundColor: `${chartColors.value.cache}20`,
label: 'Cache Creation',
data: props.trendData.map((d) => d.cache_creation_tokens),
borderColor: chartColors.value.cacheCreation,
backgroundColor: `${chartColors.value.cacheCreation}20`,
fill: true,
tension: 0.3
},
{
label: 'Cache Read',
data: props.trendData.map((d) => d.cache_read_tokens),
borderColor: chartColors.value.cacheRead,
backgroundColor: `${chartColors.value.cacheRead}20`,
fill: true,
tension: 0.3
}

View File

@@ -443,7 +443,22 @@ $env:ANTHROPIC_AUTH_TOKEN="${apiKey}"`
content = ''
}
return [{ path, content }]
const vscodeSettingsPath = activeTab.value === 'unix'
? '~/.claude/settings.json'
: '%userprofile%\\.claude\\settings.json'
const vscodeContent = `{
"env": {
"ANTHROPIC_BASE_URL": "${baseUrl}",
"ANTHROPIC_AUTH_TOKEN": "${apiKey}",
"CLAUDE_CODE_ATTRIBUTION_HEADER": "0"
}
}`
return [
{ path, content },
{ path: vscodeSettingsPath, content: vscodeContent, hint: 'VSCode Claude Code' }
]
}
function generateGeminiCliContent(baseUrl: string, apiKey: string): FileConfig {
@@ -496,16 +511,16 @@ function generateOpenAIFiles(baseUrl: string, apiKey: string): FileConfig[] {
const configDir = isWindows ? '%userprofile%\\.codex' : '~/.codex'
// config.toml content
const configContent = `model_provider = "sub2api"
const configContent = `model_provider = "OpenAI"
model = "gpt-5.3-codex"
model_reasoning_effort = "high"
network_access = "enabled"
review_model = "gpt-5.3-codex"
model_reasoning_effort = "xhigh"
disable_response_storage = true
network_access = "enabled"
windows_wsl_setup_acknowledged = true
model_verbosity = "high"
[model_providers.sub2api]
name = "sub2api"
[model_providers.OpenAI]
name = "OpenAI"
base_url = "${baseUrl}"
wire_api = "responses"
requires_openai_auth = true`
@@ -533,16 +548,16 @@ function generateOpenAIWsFiles(baseUrl: string, apiKey: string): FileConfig[] {
const configDir = isWindows ? '%userprofile%\\.codex' : '~/.codex'
// config.toml content with WebSocket v2
const configContent = `model_provider = "sub2api"
const configContent = `model_provider = "OpenAI"
model = "gpt-5.3-codex"
model_reasoning_effort = "high"
network_access = "enabled"
review_model = "gpt-5.3-codex"
model_reasoning_effort = "xhigh"
disable_response_storage = true
network_access = "enabled"
windows_wsl_setup_acknowledged = true
model_verbosity = "high"
[model_providers.sub2api]
name = "sub2api"
[model_providers.OpenAI]
name = "OpenAI"
base_url = "${baseUrl}"
wire_api = "responses"
supports_websockets = true

View File

@@ -110,6 +110,75 @@ export default {
}
},
// Key Usage Query Page
keyUsage: {
title: 'API Key Usage',
subtitle: 'Enter your API Key to view real-time spending and usage status',
placeholder: 'sk-ant-mirror-xxxxxxxxxxxx',
query: 'Query',
querying: 'Querying...',
privacyNote: 'Your Key is processed locally in the browser and will not be stored',
dateRange: 'Date Range:',
dateRangeToday: 'Today',
dateRange7d: '7 Days',
dateRange30d: '30 Days',
dateRangeCustom: 'Custom',
apply: 'Apply',
used: 'Used',
detailInfo: 'Detail Information',
tokenStats: 'Token Statistics',
modelStats: 'Model Usage Statistics',
// Table headers
model: 'Model',
requests: 'Requests',
inputTokens: 'Input Tokens',
outputTokens: 'Output Tokens',
cacheCreationTokens: 'Cache Creation',
cacheReadTokens: 'Cache Read',
totalTokens: 'Total Tokens',
cost: 'Cost',
// Status
quotaMode: 'Key Quota Mode',
walletBalance: 'Wallet Balance',
// Ring card titles
totalQuota: 'Total Quota',
limit5h: '5-Hour Limit',
limitDaily: 'Daily Limit',
limit7d: '7-Day Limit',
limitWeekly: 'Weekly Limit',
limitMonthly: 'Monthly Limit',
// Detail rows
remainingQuota: 'Remaining Quota',
expiresAt: 'Expires At',
todayExpires: '(expires today)',
daysLeft: '({days} days)',
usedQuota: 'Used Quota',
subscriptionType: 'Subscription Type',
subscriptionExpires: 'Subscription Expires',
// Usage stat cells
todayRequests: 'Today Requests',
todayInputTokens: 'Today Input',
todayOutputTokens: 'Today Output',
todayTokens: 'Today Tokens',
todayCacheCreation: 'Today Cache Creation',
todayCacheRead: 'Today Cache Read',
todayCost: 'Today Cost',
rpmTpm: 'RPM / TPM',
totalRequests: 'Total Requests',
totalInputTokens: 'Total Input',
totalOutputTokens: 'Total Output',
totalTokensLabel: 'Total Tokens',
totalCacheCreation: 'Total Cache Creation',
totalCacheRead: 'Total Cache Read',
totalCost: 'Total Cost',
avgDuration: 'Avg Duration',
// Messages
enterApiKey: 'Please enter an API Key',
querySuccess: 'Query successful',
queryFailed: 'Query failed',
queryFailedRetry: 'Query failed, please try again later',
},
// Setup Wizard
setup: {
title: 'Sub2API Setup',
@@ -1787,10 +1856,13 @@ export default {
wsMode: 'WS mode',
wsModeDesc: 'Only applies to the current OpenAI account type.',
wsModeOff: 'Off (off)',
wsModeCtxPool: 'Context Pool (ctx_pool)',
wsModePassthrough: 'Passthrough (passthrough)',
wsModeShared: 'Shared (shared)',
wsModeDedicated: 'Dedicated (dedicated)',
wsModeConcurrencyHint:
'When WS mode is enabled, account concurrency becomes the WS connection pool limit for this account.',
wsModePassthroughHint: 'Passthrough mode does not use the WS connection pool.',
oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode',
oauthResponsesWebsocketsV2Desc:
'Only applies to OpenAI OAuth. This account can use OpenAI WebSocket Mode only when enabled.',
@@ -2351,6 +2423,34 @@ export default {
'This account is not eligible for Antigravity, but API forwarding still works. Use at your own risk.'
},
// Scheduled Tests
scheduledTests: {
title: 'Scheduled Tests',
addPlan: 'Add Plan',
editPlan: 'Edit Plan',
deletePlan: 'Delete Plan',
model: 'Model',
cronExpression: 'Cron Expression',
enabled: 'Enabled',
lastRun: 'Last Run',
nextRun: 'Next Run',
maxResults: 'Max Results',
noPlans: 'No scheduled test plans',
confirmDelete: 'Are you sure you want to delete this plan?',
createSuccess: 'Plan created successfully',
updateSuccess: 'Plan updated successfully',
deleteSuccess: 'Plan deleted successfully',
results: 'Test Results',
noResults: 'No test results yet',
responseText: 'Response',
errorMessage: 'Error',
success: 'Success',
failed: 'Failed',
running: 'Running',
schedule: 'Schedule',
cronHelp: 'Standard 5-field cron expression (e.g., */30 * * * *)'
},
// Proxies
proxies: {
title: 'Proxy Management',

View File

@@ -110,6 +110,75 @@ export default {
}
},
// Key Usage Query Page
keyUsage: {
title: 'API Key 用量查询',
subtitle: '输入您的 API Key 以查看实时消费金额与使用状态',
placeholder: 'sk-ant-mirror-xxxxxxxxxxxx',
query: '查询',
querying: '查询中...',
privacyNote: '您的 Key 仅在浏览器本地处理,不会被存储',
dateRange: '统计范围:',
dateRangeToday: '今日',
dateRange7d: '7 天',
dateRange30d: '30 天',
dateRangeCustom: '自定义',
apply: '应用',
used: '已使用',
detailInfo: '详细信息',
tokenStats: 'Token 统计',
modelStats: '模型用量统计',
// Table headers
model: '模型',
requests: '请求数',
inputTokens: '输入 Tokens',
outputTokens: '输出 Tokens',
cacheCreationTokens: '缓存创建',
cacheReadTokens: '缓存读取',
totalTokens: '总 Tokens',
cost: '费用',
// Status
quotaMode: 'Key 限额模式',
walletBalance: '钱包余额',
// Ring card titles
totalQuota: '总额度',
limit5h: '5 小时限额',
limitDaily: '日限额',
limit7d: '7 天限额',
limitWeekly: '周限额',
limitMonthly: '月限额',
// Detail rows
remainingQuota: '剩余额度',
expiresAt: '过期时间',
todayExpires: '(今日到期)',
daysLeft: '({days} 天)',
usedQuota: '已用额度',
subscriptionType: '订阅类型',
subscriptionExpires: '订阅到期',
// Usage stat cells
todayRequests: '今日请求',
todayInputTokens: '今日输入',
todayOutputTokens: '今日输出',
todayTokens: '今日 Tokens',
todayCacheCreation: '今日缓存创建',
todayCacheRead: '今日缓存读取',
todayCost: '今日费用',
rpmTpm: 'RPM / TPM',
totalRequests: '累计请求',
totalInputTokens: '累计输入',
totalOutputTokens: '累计输出',
totalTokensLabel: '累计 Tokens',
totalCacheCreation: '累计缓存创建',
totalCacheRead: '累计缓存读取',
totalCost: '累计费用',
avgDuration: '平均耗时',
// Messages
enterApiKey: '请输入 API Key',
querySuccess: '查询成功',
queryFailed: '查询失败',
queryFailedRetry: '查询失败,请稍后重试',
},
// Setup Wizard
setup: {
title: 'Sub2API 安装向导',
@@ -1935,9 +2004,12 @@ export default {
wsMode: 'WS mode',
wsModeDesc: '仅对当前 OpenAI 账号类型生效。',
wsModeOff: '关闭off',
wsModeCtxPool: '上下文池ctx_pool',
wsModePassthrough: '透传passthrough',
wsModeShared: '共享shared',
wsModeDedicated: '独享dedicated',
wsModeConcurrencyHint: '启用 WS mode 后,该账号并发数将作为该账号 WS 连接池上限。',
wsModePassthroughHint: 'passthrough 模式不使用 WS 连接池。',
oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode',
oauthResponsesWebsocketsV2Desc:
'仅对 OpenAI OAuth 生效。开启后该账号才允许使用 OpenAI WebSocket Mode 协议。',
@@ -2458,6 +2530,34 @@ export default {
}
},
// Scheduled Tests
scheduledTests: {
title: '定时测试',
addPlan: '添加计划',
editPlan: '编辑计划',
deletePlan: '删除计划',
model: '模型',
cronExpression: 'Cron 表达式',
enabled: '启用',
lastRun: '上次运行',
nextRun: '下次运行',
maxResults: '最大结果数',
noPlans: '暂无定时测试计划',
confirmDelete: '确定要删除此计划吗?',
createSuccess: '计划创建成功',
updateSuccess: '计划更新成功',
deleteSuccess: '计划删除成功',
results: '测试结果',
noResults: '暂无测试结果',
responseText: '响应',
errorMessage: '错误',
success: '成功',
failed: '失败',
running: '运行中',
schedule: '定时测试',
cronHelp: '标准 5 字段 cron 表达式(例如 */30 * * * *'
},
// Proxies Management
proxies: {
title: 'IP管理',

View File

@@ -102,6 +102,15 @@ const routes: RouteRecordRaw[] = [
title: 'Reset Password'
}
},
{
path: '/key-usage',
name: 'KeyUsage',
component: () => import('@/views/KeyUsageView.vue'),
meta: {
requiresAuth: false,
title: 'Key Usage',
}
},
// ==================== User Routes ====================
{

View File

@@ -1098,7 +1098,8 @@ export interface TrendDataPoint {
requests: number
input_tokens: number
output_tokens: number
cache_tokens: number
cache_creation_tokens: number
cache_read_tokens: number
total_tokens: number
cost: number // 标准计费
actual_cost: number // 实际扣除
@@ -1109,6 +1110,8 @@ export interface ModelStat {
requests: number
input_tokens: number
output_tokens: number
cache_creation_tokens: number
cache_read_tokens: number
total_tokens: number
cost: number // 标准计费
actual_cost: number // 实际扣除
@@ -1457,3 +1460,45 @@ export interface TotpLogin2FARequest {
temp_token: string
totp_code: string
}
// ==================== Scheduled Test Types ====================
export interface ScheduledTestPlan {
id: number
account_id: number
model_id: string
cron_expression: string
enabled: boolean
max_results: number
last_run_at: string | null
next_run_at: string | null
created_at: string
updated_at: string
}
export interface ScheduledTestResult {
id: number
plan_id: number
status: string
response_text: string
error_message: string
latency_ms: number
started_at: string
finished_at: string
created_at: string
}
export interface CreateScheduledTestPlanRequest {
account_id: number
model_id: string
cron_expression: string
enabled?: boolean
max_results?: number
}
export interface UpdateScheduledTestPlanRequest {
model_id?: string
cron_expression?: string
enabled?: boolean
max_results?: number
}

View File

@@ -1,31 +1,34 @@
import { describe, expect, it } from 'vitest'
import {
OPENAI_WS_MODE_DEDICATED,
OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_SHARED,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
normalizeOpenAIWSMode,
openAIWSModeFromEnabled,
resolveOpenAIWSModeConcurrencyHintKey,
resolveOpenAIWSModeFromExtra
} from '@/utils/openaiWsMode'
describe('openaiWsMode utils', () => {
it('normalizes mode values', () => {
expect(normalizeOpenAIWSMode('off')).toBe(OPENAI_WS_MODE_OFF)
expect(normalizeOpenAIWSMode(' Shared ')).toBe(OPENAI_WS_MODE_SHARED)
expect(normalizeOpenAIWSMode('DEDICATED')).toBe(OPENAI_WS_MODE_DEDICATED)
expect(normalizeOpenAIWSMode('ctx_pool')).toBe(OPENAI_WS_MODE_CTX_POOL)
expect(normalizeOpenAIWSMode('passthrough')).toBe(OPENAI_WS_MODE_PASSTHROUGH)
expect(normalizeOpenAIWSMode(' Shared ')).toBe(OPENAI_WS_MODE_CTX_POOL)
expect(normalizeOpenAIWSMode('DEDICATED')).toBe(OPENAI_WS_MODE_CTX_POOL)
expect(normalizeOpenAIWSMode('invalid')).toBeNull()
})
it('maps legacy enabled flag to mode', () => {
expect(openAIWSModeFromEnabled(true)).toBe(OPENAI_WS_MODE_SHARED)
expect(openAIWSModeFromEnabled(true)).toBe(OPENAI_WS_MODE_CTX_POOL)
expect(openAIWSModeFromEnabled(false)).toBe(OPENAI_WS_MODE_OFF)
expect(openAIWSModeFromEnabled('true')).toBeNull()
})
it('resolves by mode key first, then enabled, then fallback enabled keys', () => {
const extra = {
openai_oauth_responses_websockets_v2_mode: 'dedicated',
openai_oauth_responses_websockets_v2_mode: 'passthrough',
openai_oauth_responses_websockets_v2_enabled: false,
responses_websockets_v2_enabled: false
}
@@ -34,7 +37,7 @@ describe('openaiWsMode utils', () => {
enabledKey: 'openai_oauth_responses_websockets_v2_enabled',
fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled']
})
expect(mode).toBe(OPENAI_WS_MODE_DEDICATED)
expect(mode).toBe(OPENAI_WS_MODE_PASSTHROUGH)
})
it('falls back to default when nothing is present', () => {
@@ -47,9 +50,21 @@ describe('openaiWsMode utils', () => {
expect(mode).toBe(OPENAI_WS_MODE_OFF)
})
it('treats off as disabled and shared/dedicated as enabled', () => {
it('treats off as disabled and non-off modes as enabled', () => {
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_OFF)).toBe(false)
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_SHARED)).toBe(true)
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_DEDICATED)).toBe(true)
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_CTX_POOL)).toBe(true)
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_PASSTHROUGH)).toBe(true)
})
it('resolves concurrency hint key by mode', () => {
expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_OFF)).toBe(
'admin.accounts.openai.wsModeConcurrencyHint'
)
expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_CTX_POOL)).toBe(
'admin.accounts.openai.wsModeConcurrencyHint'
)
expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_PASSTHROUGH)).toBe(
'admin.accounts.openai.wsModePassthroughHint'
)
})
})

View File

@@ -1,16 +1,16 @@
export const OPENAI_WS_MODE_OFF = 'off'
export const OPENAI_WS_MODE_SHARED = 'shared'
export const OPENAI_WS_MODE_DEDICATED = 'dedicated'
export const OPENAI_WS_MODE_CTX_POOL = 'ctx_pool'
export const OPENAI_WS_MODE_PASSTHROUGH = 'passthrough'
export type OpenAIWSMode =
| typeof OPENAI_WS_MODE_OFF
| typeof OPENAI_WS_MODE_SHARED
| typeof OPENAI_WS_MODE_DEDICATED
| typeof OPENAI_WS_MODE_CTX_POOL
| typeof OPENAI_WS_MODE_PASSTHROUGH
const OPENAI_WS_MODES = new Set<OpenAIWSMode>([
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_SHARED,
OPENAI_WS_MODE_DEDICATED
OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_PASSTHROUGH
])
export interface ResolveOpenAIWSModeOptions {
@@ -23,6 +23,9 @@ export interface ResolveOpenAIWSModeOptions {
export const normalizeOpenAIWSMode = (mode: unknown): OpenAIWSMode | null => {
if (typeof mode !== 'string') return null
const normalized = mode.trim().toLowerCase()
if (normalized === 'shared' || normalized === 'dedicated') {
return OPENAI_WS_MODE_CTX_POOL
}
if (OPENAI_WS_MODES.has(normalized as OpenAIWSMode)) {
return normalized as OpenAIWSMode
}
@@ -31,13 +34,22 @@ export const normalizeOpenAIWSMode = (mode: unknown): OpenAIWSMode | null => {
export const openAIWSModeFromEnabled = (enabled: unknown): OpenAIWSMode | null => {
if (typeof enabled !== 'boolean') return null
return enabled ? OPENAI_WS_MODE_SHARED : OPENAI_WS_MODE_OFF
return enabled ? OPENAI_WS_MODE_CTX_POOL : OPENAI_WS_MODE_OFF
}
export const isOpenAIWSModeEnabled = (mode: OpenAIWSMode): boolean => {
return mode !== OPENAI_WS_MODE_OFF
}
export const resolveOpenAIWSModeConcurrencyHintKey = (
mode: OpenAIWSMode
): 'admin.accounts.openai.wsModeConcurrencyHint' | 'admin.accounts.openai.wsModePassthroughHint' => {
if (mode === OPENAI_WS_MODE_PASSTHROUGH) {
return 'admin.accounts.openai.wsModePassthroughHint'
}
return 'admin.accounts.openai.wsModeConcurrencyHint'
}
export const resolveOpenAIWSModeFromExtra = (
extra: Record<string, unknown> | null | undefined,
options: ResolveOpenAIWSModeOptions

View File

@@ -0,0 +1,870 @@
<template>
<div class="relative flex min-h-screen flex-col bg-gray-50 dark:bg-dark-950">
<!-- Header (same pattern as HomeView) -->
<header class="relative z-20 px-6 py-4">
<nav class="mx-auto flex max-w-6xl items-center justify-between">
<router-link to="/home" class="flex items-center gap-3">
<div class="h-10 w-10 overflow-hidden rounded-xl shadow-md">
<img :src="siteLogo || '/logo.png'" alt="Logo" class="h-full w-full object-contain" />
</div>
<span class="text-lg font-semibold tracking-tight text-gray-900 dark:text-white">{{ siteName }}</span>
</router-link>
<div class="flex items-center gap-3">
<LocaleSwitcher />
<a
v-if="docUrl"
:href="docUrl"
target="_blank"
rel="noopener noreferrer"
class="rounded-lg p-2 text-gray-500 transition-colors hover:bg-gray-100 hover:text-gray-700 dark:text-dark-400 dark:hover:bg-dark-800 dark:hover:text-white"
:title="t('home.viewDocs')"
>
<Icon name="book" size="md" />
</a>
<button
@click="toggleTheme"
class="rounded-lg p-2 text-gray-500 transition-colors hover:bg-gray-100 hover:text-gray-700 dark:text-dark-400 dark:hover:bg-dark-800 dark:hover:text-white"
:title="isDark ? t('home.switchToLight') : t('home.switchToDark')"
>
<Icon v-if="isDark" name="sun" size="md" />
<Icon v-else name="moon" size="md" />
</button>
</div>
</nav>
</header>
<!-- Main Content -->
<main class="flex-1 w-full max-w-5xl mx-auto px-6 py-12">
<!-- Hero -->
<div class="text-center mb-12">
<h1 class="text-3xl sm:text-4xl font-bold tracking-tight mb-3 text-gray-900 dark:text-white">
{{ t('keyUsage.title') }}
</h1>
<p class="text-gray-500 dark:text-dark-400 text-base max-w-md mx-auto">
{{ t('keyUsage.subtitle') }}
</p>
</div>
<!-- Input Section -->
<div class="max-w-xl mx-auto mb-14">
<div class="flex gap-3">
<div class="flex-1 relative">
<div class="absolute left-4 top-1/2 -translate-y-1/2 text-gray-400 dark:text-dark-500">
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect x="3" y="11" width="18" height="11" rx="2" ry="2"/><path d="M7 11V7a5 5 0 0 1 10 0v4"/>
</svg>
</div>
<input
v-model="apiKey"
:type="keyVisible ? 'text' : 'password'"
:placeholder="t('keyUsage.placeholder')"
class="input-ring w-full h-12 pl-12 pr-12 rounded-xl border border-gray-200 bg-white text-sm text-gray-900 placeholder:text-gray-400 transition-all dark:border-dark-700 dark:bg-dark-900 dark:text-white dark:placeholder:text-dark-500"
@keydown.enter="queryKey"
/>
<button
@click="keyVisible = !keyVisible"
class="absolute right-4 top-1/2 -translate-y-1/2 text-gray-400 hover:text-gray-700 dark:text-dark-500 dark:hover:text-white transition-colors"
>
<svg v-if="!keyVisible" class="w-5 h-5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M17.94 17.94A10.07 10.07 0 0 1 12 20c-7 0-11-8-11-8a18.45 18.45 0 0 1 5.06-5.94M9.9 4.24A9.12 9.12 0 0 1 12 4c7 0 11 8 11 8a18.5 18.5 0 0 1-2.16 3.19m-6.72-1.07a3 3 0 1 1-4.24-4.24"/>
<line x1="1" y1="1" x2="23" y2="23"/>
</svg>
<svg v-else class="w-5 h-5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z"/><circle cx="12" cy="12" r="3"/>
</svg>
</button>
</div>
<button
@click="queryKey"
:disabled="isQuerying"
class="h-12 px-7 rounded-xl bg-primary-500 hover:bg-primary-600 text-white font-medium text-sm transition-all active:scale-[0.97] flex items-center gap-2 whitespace-nowrap disabled:opacity-60"
>
<svg v-if="isQuerying" class="w-4 h-4 animate-spin" viewBox="0 0 24 24" fill="none">
<circle cx="12" cy="12" r="10" stroke="currentColor" stroke-width="3" opacity="0.25"/>
<path d="M12 2a10 10 0 0 1 10 10" stroke="currentColor" stroke-width="3" stroke-linecap="round"/>
</svg>
<svg v-else class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round">
<circle cx="11" cy="11" r="8"/><line x1="21" y1="21" x2="16.65" y2="16.65"/>
</svg>
{{ isQuerying ? t('keyUsage.querying') : t('keyUsage.query') }}
</button>
</div>
<p class="text-xs text-gray-400 dark:text-dark-500 mt-3 text-center">
{{ t('keyUsage.privacyNote') }}
</p>
<!-- Date Range Picker -->
<div v-if="showDatePicker" class="mt-4">
<div class="flex flex-wrap items-center gap-2 justify-center">
<span class="text-xs text-gray-500 dark:text-dark-400">{{ t('keyUsage.dateRange') }}</span>
<button
v-for="range in dateRanges"
:key="range.key"
@click="setDateRange(range.key)"
class="text-xs px-3 py-1.5 rounded-lg border transition-all"
:class="currentRange === range.key
? 'bg-primary-500 text-white border-primary-500'
: 'border-gray-200 bg-white text-gray-700 dark:border-dark-700 dark:bg-dark-900 dark:text-dark-200 hover:border-primary-300 dark:hover:border-dark-600'"
>{{ range.label }}</button>
<div v-if="currentRange === 'custom'" class="flex items-center gap-2 ml-1">
<input
v-model="customStartDate"
type="date"
class="input-ring text-xs px-2 py-1.5 rounded-lg border border-gray-200 bg-white text-gray-900 dark:border-dark-700 dark:bg-dark-900 dark:text-white"
/>
<span class="text-xs text-gray-400">-</span>
<input
v-model="customEndDate"
type="date"
class="input-ring text-xs px-2 py-1.5 rounded-lg border border-gray-200 bg-white text-gray-900 dark:border-dark-700 dark:bg-dark-900 dark:text-white"
/>
<button
@click="queryKey"
class="text-xs px-3 py-1.5 rounded-lg bg-primary-500 text-white hover:bg-primary-600"
>{{ t('keyUsage.apply') }}</button>
</div>
</div>
</div>
</div>
<!-- Results Container -->
<div v-if="showResults">
<!-- Loading Skeleton -->
<div v-if="showLoading" class="space-y-6">
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div class="rounded-2xl border border-gray-200 bg-white p-8 dark:border-dark-700 dark:bg-dark-900">
<div class="skeleton h-5 w-24 mb-6"></div>
<div class="flex justify-center"><div class="skeleton w-44 h-44 rounded-full"></div></div>
</div>
<div class="rounded-2xl border border-gray-200 bg-white p-8 dark:border-dark-700 dark:bg-dark-900">
<div class="skeleton h-5 w-24 mb-6"></div>
<div class="flex justify-center"><div class="skeleton w-44 h-44 rounded-full"></div></div>
</div>
</div>
<div class="rounded-2xl border border-gray-200 bg-white p-8 dark:border-dark-700 dark:bg-dark-900">
<div class="skeleton h-5 w-32 mb-6"></div>
<div class="space-y-4">
<div class="skeleton h-4 w-full"></div>
<div class="skeleton h-4 w-3/4"></div>
<div class="skeleton h-4 w-5/6"></div>
<div class="skeleton h-4 w-2/3"></div>
</div>
</div>
</div>
<!-- Result Content -->
<div v-else-if="resultData" class="space-y-6">
<!-- Status Badge -->
<div v-if="statusInfo" class="fade-up flex items-center justify-center mb-2">
<div class="inline-flex items-center gap-2 px-5 py-2.5 rounded-full border border-gray-200 bg-white/90 shadow-sm backdrop-blur-sm dark:border-dark-700 dark:bg-dark-900/90">
<span
class="w-2.5 h-2.5 rounded-full pulse-dot"
:class="statusInfo.isActive ? 'bg-emerald-500' : 'bg-rose-500'"
></span>
<span class="text-sm font-medium text-gray-900 dark:text-white">{{ statusInfo.label }}</span>
<span class="text-xs text-gray-400 dark:text-dark-500">|</span>
<span class="text-xs text-gray-500 dark:text-dark-400">{{ statusInfo.statusText }}</span>
</div>
</div>
<!-- Ring Cards Grid -->
<div v-if="ringItems.length > 0" :class="ringGridClass">
<div
v-for="(ring, i) in ringItems"
:key="i"
class="fade-up rounded-2xl border border-gray-200 bg-white/90 p-8 backdrop-blur-sm transition-all duration-300 hover:shadow-lg dark:border-dark-700 dark:bg-dark-900/90"
:class="`fade-up-delay-${Math.min(i + 1, 4)}`"
>
<div class="flex items-center justify-between mb-6">
<h3 class="text-sm font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">
{{ ring.title }}
</h3>
<!-- Clock icon -->
<svg v-if="ring.iconType === 'clock'" class="w-5 h-5 text-gray-400 dark:text-dark-500" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<circle cx="12" cy="12" r="10"/><polyline points="12 6 12 12 16 14"/>
</svg>
<!-- Calendar icon -->
<svg v-else-if="ring.iconType === 'calendar'" class="w-5 h-5 text-gray-400 dark:text-dark-500" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect x="3" y="4" width="18" height="18" rx="2" ry="2"/><line x1="16" y1="2" x2="16" y2="6"/><line x1="8" y1="2" x2="8" y2="6"/><line x1="3" y1="10" x2="21" y2="10"/>
</svg>
<!-- Dollar icon -->
<svg v-else class="w-5 h-5 text-gray-400 dark:text-dark-500" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<line x1="12" y1="1" x2="12" y2="23"/><path d="M17 5H9.5a3.5 3.5 0 0 0 0 7h5a3.5 3.5 0 0 1 0 7H6"/>
</svg>
</div>
<div class="flex justify-center">
<div class="relative">
<svg class="w-44 h-44" viewBox="0 0 160 160">
<circle cx="80" cy="80" r="68" fill="none" :stroke="ringTrackColor" stroke-width="10"/>
<circle
class="progress-ring"
cx="80" cy="80" r="68" fill="none"
:stroke="`url(#ring-grad-${i})`"
stroke-width="10" stroke-linecap="round"
:stroke-dasharray="CIRCUMFERENCE.toFixed(2)"
:stroke-dashoffset="getRingOffset(ring)"
/>
<defs>
<linearGradient :id="`ring-grad-${i}`" x1="0%" y1="0%" x2="100%" y2="100%">
<stop offset="0%" :stop-color="RING_GRADIENTS[i % 4].from"/>
<stop offset="100%" :stop-color="RING_GRADIENTS[i % 4].to"/>
</linearGradient>
</defs>
</svg>
<div class="absolute inset-0 flex flex-col items-center justify-center">
<template v-if="ring.isBalance">
<span class="text-2xl font-bold tabular-nums" :style="{ color: RING_GRADIENTS[i % 4].from }">
{{ ring.amount }}
</span>
</template>
<template v-else>
<span class="text-3xl font-bold tabular-nums text-gray-900 dark:text-white">
{{ displayPcts[i] ?? 0 }}%
</span>
<span class="text-xs text-gray-500 dark:text-dark-400 mt-0.5">{{ t('keyUsage.used') }}</span>
<span
class="text-sm font-semibold mt-1 tabular-nums"
:style="{ color: RING_GRADIENTS[i % 4].from }"
>{{ ring.amount }}</span>
</template>
</div>
</div>
</div>
</div>
</div>
<!-- Detail Card -->
<div
v-if="detailRows.length > 0"
class="fade-up fade-up-delay-3 rounded-2xl border border-gray-200 bg-white/90 backdrop-blur-sm overflow-hidden dark:border-dark-700 dark:bg-dark-900/90"
>
<div class="px-8 py-5 border-b border-gray-200 dark:border-dark-700">
<h3 class="text-sm font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.detailInfo') }}</h3>
</div>
<div class="divide-y divide-gray-100 dark:divide-dark-800">
<div
v-for="(row, i) in detailRows"
:key="i"
class="px-8 py-4 flex items-center justify-between"
>
<div class="flex items-center gap-3">
<div class="w-8 h-8 rounded-lg flex items-center justify-center" :class="row.iconBg">
<svg
class="w-4 h-4"
:class="row.iconColor"
viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
v-html="row.iconSvg"
></svg>
</div>
<span class="text-sm text-gray-700 dark:text-dark-200">{{ row.label }}</span>
</div>
<span class="text-sm font-semibold tabular-nums" :class="row.valueClass || 'text-gray-900 dark:text-white'">
{{ row.value }}
</span>
</div>
</div>
</div>
<!-- Usage Stats Card -->
<div
v-if="usageStatCells.length > 0"
class="fade-up fade-up-delay-3 rounded-2xl border border-gray-200 bg-white/90 backdrop-blur-sm overflow-hidden dark:border-dark-700 dark:bg-dark-900/90"
>
<div class="px-8 py-5 border-b border-gray-200 dark:border-dark-700">
<h3 class="text-sm font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.tokenStats') }}</h3>
</div>
<div class="grid grid-cols-2 md:grid-cols-4 gap-px bg-gray-100 dark:bg-dark-800">
<div
v-for="(cell, i) in usageStatCells"
:key="i"
class="bg-white px-6 py-4 dark:bg-dark-900"
>
<div class="text-xs text-gray-500 dark:text-dark-400 mb-1">{{ cell.label }}</div>
<div class="text-sm font-semibold tabular-nums text-gray-900 dark:text-white">{{ cell.value }}</div>
</div>
</div>
</div>
<!-- Model Stats Table -->
<div
v-if="modelStats.length > 0"
class="fade-up fade-up-delay-4 rounded-2xl border border-gray-200 bg-white/90 backdrop-blur-sm overflow-hidden dark:border-dark-700 dark:bg-dark-900/90"
>
<div class="px-8 py-5 border-b border-gray-200 dark:border-dark-700">
<h3 class="text-sm font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.modelStats') }}</h3>
</div>
<div class="overflow-x-auto">
<table class="w-full">
<thead>
<tr class="border-b border-gray-200 bg-gray-50 dark:border-dark-700 dark:bg-dark-950">
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.model') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.requests') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.inputTokens') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.outputTokens') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cacheCreationTokens') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cacheReadTokens') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.totalTokens') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cost') }}</th>
</tr>
</thead>
<tbody>
<tr
v-for="(m, i) in modelStats"
:key="i"
class="border-b border-gray-100 last:border-b-0 dark:border-dark-800"
>
<td class="px-4 py-3 text-sm font-medium whitespace-nowrap text-gray-900 dark:text-white">{{ m.model || '-' }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(m.requests) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(m.input_tokens) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(m.output_tokens) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(m.cache_creation_tokens) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(m.cache_read_tokens) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(m.total_tokens) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right font-medium text-gray-900 dark:text-white">{{ usd(m.actual_cost != null ? m.actual_cost : m.cost) }}</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
</div>
</main>
<!-- Footer (same pattern as HomeView) -->
<footer class="relative z-10 border-t border-gray-200/50 px-6 py-8 dark:border-dark-800/50">
<div class="mx-auto flex max-w-6xl flex-col items-center justify-center gap-4 text-center sm:flex-row sm:text-left">
<p class="text-sm text-gray-500 dark:text-dark-400">
&copy; {{ currentYear }} {{ siteName }}. {{ t('home.footer.allRightsReserved') }}
</p>
<div class="flex items-center gap-4">
<a
v-if="docUrl"
:href="docUrl"
target="_blank"
rel="noopener noreferrer"
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
>{{ t('home.docs') }}</a>
<a
:href="githubUrl"
target="_blank"
rel="noopener noreferrer"
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
>GitHub</a>
</div>
</div>
</footer>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted, nextTick } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores'
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue'
import Icon from '@/components/icons/Icon.vue'
const { t, locale } = useI18n()
const appStore = useAppStore()
// ==================== Site Settings (same as HomeView) ====================
const siteName = computed(() => appStore.cachedPublicSettings?.site_name || appStore.siteName || 'Sub2API')
const siteLogo = computed(() => appStore.cachedPublicSettings?.site_logo || appStore.siteLogo || '')
const docUrl = computed(() => appStore.cachedPublicSettings?.doc_url || appStore.docUrl || '')
const githubUrl = 'https://github.com/Wei-Shaw/sub2api'
// ==================== Theme (same as HomeView) ====================
const isDark = ref(document.documentElement.classList.contains('dark'))
function toggleTheme() {
isDark.value = !isDark.value
document.documentElement.classList.toggle('dark', isDark.value)
localStorage.setItem('theme', isDark.value ? 'dark' : 'light')
}
const currentYear = computed(() => new Date().getFullYear())
// ==================== Key Query State ====================
const apiKey = ref('')
const keyVisible = ref(false)
const isQuerying = ref(false)
const showResults = ref(false)
const showLoading = ref(false)
const showDatePicker = ref(false)
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const resultData = ref<any>(null)
// ==================== Date Range State ====================
type DateRangeKey = 'today' | '7d' | '30d' | 'custom'
const currentRange = ref<DateRangeKey>('today')
const customStartDate = ref('')
const customEndDate = ref('')
const dateRanges = computed(() => [
{ key: 'today' as const, label: t('keyUsage.dateRangeToday') },
{ key: '7d' as const, label: t('keyUsage.dateRange7d') },
{ key: '30d' as const, label: t('keyUsage.dateRange30d') },
{ key: 'custom' as const, label: t('keyUsage.dateRangeCustom') },
])
function setDateRange(key: DateRangeKey) {
currentRange.value = key
if (key !== 'custom') {
queryKey()
}
}
function getDateParams(): string {
const now = new Date()
const fmt = (d: Date) => d.toISOString().split('T')[0]
if (currentRange.value === 'custom') {
if (customStartDate.value && customEndDate.value) {
return `start_date=${customStartDate.value}&end_date=${customEndDate.value}`
}
return ''
}
const end = fmt(now)
let start: string
switch (currentRange.value) {
case 'today': start = end; break
case '7d': start = fmt(new Date(now.getTime() - 7 * 86400000)); break
case '30d': start = fmt(new Date(now.getTime() - 30 * 86400000)); break
default: start = fmt(new Date(now.getTime() - 30 * 86400000))
}
return `start_date=${start}&end_date=${end}`
}
// ==================== Ring Animation ====================
const CIRCUMFERENCE = 2 * Math.PI * 68
const RING_GRADIENTS = [
{ from: '#14b8a6', to: '#5eead4' },
{ from: '#6366F1', to: '#A5B4FC' },
{ from: '#10B981', to: '#6EE7B7' },
{ from: '#F59E0B', to: '#FCD34D' },
]
const ringAnimated = ref(false)
const displayPcts = ref<number[]>([])
const ringTrackColor = computed(() => isDark.value ? '#222222' : '#F0F0EE')
interface RingItem {
title: string
pct: number
amount: string
isBalance?: boolean
iconType: 'clock' | 'calendar' | 'dollar'
}
function getRingOffset(ring: RingItem): number {
if (!ringAnimated.value) return CIRCUMFERENCE
if (ring.isBalance) return 0
return CIRCUMFERENCE - (Math.min(ring.pct, 100) / 100) * CIRCUMFERENCE
}
function triggerRingAnimation(items: RingItem[]) {
ringAnimated.value = false
displayPcts.value = items.map(() => 0)
nextTick(() => {
requestAnimationFrame(() => {
setTimeout(() => {
ringAnimated.value = true
// Animate percentage numbers
const duration = 1000
const startTime = performance.now()
const targets = items.map(item => item.isBalance ? 0 : item.pct)
function tick() {
const elapsed = performance.now() - startTime
const p = Math.min(elapsed / duration, 1)
const ease = 1 - Math.pow(1 - p, 3)
displayPcts.value = targets.map(target => Math.round(ease * target))
if (p < 1) requestAnimationFrame(tick)
}
requestAnimationFrame(tick)
}, 50)
})
})
}
// ==================== Computed Data ====================
const statusInfo = computed(() => {
const data = resultData.value
if (!data) return null
if (data.mode === 'quota_limited') {
const isValid = data.isValid !== false
const statusMap: Record<string, string> = {
active: 'Active',
quota_exhausted: 'Quota Exhausted',
expired: 'Expired',
}
return {
label: t('keyUsage.quotaMode'),
statusText: statusMap[data.status] || data.status || 'Unknown',
isActive: isValid && data.status === 'active',
}
}
return {
label: data.planName || t('keyUsage.walletBalance'),
statusText: 'Active',
isActive: true,
}
})
const ringItems = computed<RingItem[]>(() => {
const data = resultData.value
if (!data) return []
const items: RingItem[] = []
if (data.mode === 'quota_limited') {
if (data.quota) {
const pct = data.quota.limit > 0 ? Math.min(Math.round((data.quota.used / data.quota.limit) * 100), 100) : 0
items.push({ title: t('keyUsage.totalQuota'), pct, amount: `${usd(data.quota.used)} / ${usd(data.quota.limit)}`, iconType: 'dollar' })
}
if (data.rate_limits) {
const windowLabels: Record<string, string> = { '5h': t('keyUsage.limit5h'), '1d': t('keyUsage.limitDaily'), '7d': t('keyUsage.limit7d') }
const windowIcons: Record<string, 'clock' | 'calendar'> = { '5h': 'clock', '1d': 'calendar', '7d': 'calendar' }
for (const rl of data.rate_limits) {
const pct = rl.limit > 0 ? Math.min(Math.round((rl.used / rl.limit) * 100), 100) : 0
items.push({
title: windowLabels[rl.window] || rl.window,
pct,
amount: `${usd(rl.used)} / ${usd(rl.limit)}`,
iconType: windowIcons[rl.window] || 'clock',
})
}
}
} else {
if (data.subscription) {
const sub = data.subscription
const limits = [
{ label: t('keyUsage.limitDaily'), usage: sub.daily_usage_usd, limit: sub.daily_limit_usd },
{ label: t('keyUsage.limitWeekly'), usage: sub.weekly_usage_usd, limit: sub.weekly_limit_usd },
{ label: t('keyUsage.limitMonthly'), usage: sub.monthly_usage_usd, limit: sub.monthly_limit_usd },
]
for (const l of limits) {
if (l.limit != null && l.limit > 0) {
const pct = Math.min(Math.round((l.usage / l.limit) * 100), 100)
items.push({ title: l.label, pct, amount: `${usd(l.usage)} / ${usd(l.limit)}`, iconType: 'calendar' })
}
}
}
if (!data.subscription && data.balance != null) {
items.push({ title: t('keyUsage.walletBalance'), pct: 0, amount: usd(data.balance), isBalance: true, iconType: 'dollar' })
}
}
return items
})
const ringGridClass = computed(() => {
const len = ringItems.value.length
if (len === 1) return 'grid grid-cols-1 max-w-md mx-auto gap-6'
if (len === 2) return 'grid grid-cols-1 md:grid-cols-2 gap-6'
return 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6'
})
interface DetailRow {
iconBg: string
iconColor: string
iconSvg: string
label: string
value: string
valueClass: string
}
function getUsageColor(pct: number): string {
if (pct > 90) return 'text-rose-500'
if (pct > 70) return 'text-amber-500'
return 'text-emerald-500'
}
const detailRows = computed<DetailRow[]>(() => {
const data = resultData.value
if (!data) return []
const rows: DetailRow[] = []
const ICON_SHIELD = '<path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z"/>'
const ICON_CALENDAR = '<rect x="3" y="4" width="18" height="18" rx="2" ry="2"/><line x1="16" y1="2" x2="16" y2="6"/><line x1="8" y1="2" x2="8" y2="6"/><line x1="3" y1="10" x2="21" y2="10"/>'
const ICON_DOLLAR = '<line x1="12" y1="1" x2="12" y2="23"/><path d="M17 5H9.5a3.5 3.5 0 0 0 0 7h5a3.5 3.5 0 0 1 0 7H6"/>'
const ICON_CHECK = '<polyline points="20 6 9 17 4 12"/>'
if (data.mode === 'quota_limited') {
if (data.quota) {
const remainColor = data.quota.remaining <= 0 ? 'text-rose-500'
: data.quota.remaining < data.quota.limit * 0.1 ? 'text-amber-500'
: 'text-emerald-500'
rows.push({
iconBg: 'bg-emerald-500/10', iconColor: 'text-emerald-500', iconSvg: ICON_SHIELD,
label: t('keyUsage.remainingQuota'), value: usd(data.quota.remaining), valueClass: remainColor,
})
}
if (data.expires_at) {
const daysLeft = data.days_until_expiry
let expiryStr = formatDate(data.expires_at)
if (daysLeft != null) {
expiryStr += daysLeft > 0 ? ` ${t('keyUsage.daysLeft', { days: daysLeft })}` : daysLeft === 0 ? ` ${t('keyUsage.todayExpires')}` : ''
}
rows.push({
iconBg: 'bg-amber-500/10', iconColor: 'text-amber-500', iconSvg: ICON_CALENDAR,
label: t('keyUsage.expiresAt'), value: expiryStr, valueClass: '',
})
}
if (data.rate_limits) {
const windowMap: Record<string, string> = { '5h': '5H', '1d': locale.value === 'zh' ? '日' : 'D', '7d': '7D' }
for (const rl of data.rate_limits) {
const pct = rl.limit > 0 ? (rl.used / rl.limit) * 100 : 0
rows.push({
iconBg: 'bg-primary-500/10', iconColor: 'text-primary-500', iconSvg: ICON_DOLLAR,
label: `${t('keyUsage.usedQuota')} (${windowMap[rl.window] || rl.window})`,
value: `${usd(rl.used)} / ${usd(rl.limit)}`,
valueClass: getUsageColor(pct),
})
}
}
} else {
rows.push({
iconBg: 'bg-emerald-500/10', iconColor: 'text-emerald-500', iconSvg: ICON_CHECK,
label: t('keyUsage.subscriptionType'), value: data.planName || t('keyUsage.walletBalance'), valueClass: '',
})
if (data.subscription) {
const sub = data.subscription
if (sub.daily_limit_usd > 0) {
const pct = (sub.daily_usage_usd / sub.daily_limit_usd) * 100
rows.push({
iconBg: 'bg-primary-500/10', iconColor: 'text-primary-500', iconSvg: ICON_DOLLAR,
label: `${t('keyUsage.usedQuota')} (${locale.value === 'zh' ? '日' : 'D'})`, value: `${usd(sub.daily_usage_usd)} / ${usd(sub.daily_limit_usd)}`, valueClass: getUsageColor(pct),
})
}
if (sub.weekly_limit_usd > 0) {
const pct = (sub.weekly_usage_usd / sub.weekly_limit_usd) * 100
rows.push({
iconBg: 'bg-indigo-500/10', iconColor: 'text-indigo-500', iconSvg: ICON_DOLLAR,
label: `${t('keyUsage.usedQuota')} (${locale.value === 'zh' ? '周' : 'W'})`, value: `${usd(sub.weekly_usage_usd)} / ${usd(sub.weekly_limit_usd)}`, valueClass: getUsageColor(pct),
})
}
if (sub.monthly_limit_usd > 0) {
const pct = (sub.monthly_usage_usd / sub.monthly_limit_usd) * 100
rows.push({
iconBg: 'bg-emerald-500/10', iconColor: 'text-emerald-500', iconSvg: ICON_DOLLAR,
label: `${t('keyUsage.usedQuota')} (${locale.value === 'zh' ? '月' : 'M'})`, value: `${usd(sub.monthly_usage_usd)} / ${usd(sub.monthly_limit_usd)}`, valueClass: getUsageColor(pct),
})
}
if (sub.expires_at) {
rows.push({
iconBg: 'bg-amber-500/10', iconColor: 'text-amber-500', iconSvg: ICON_CALENDAR,
label: t('keyUsage.subscriptionExpires'), value: formatDate(sub.expires_at), valueClass: '',
})
}
}
const remainColor = data.remaining != null
? (data.remaining <= 0 ? 'text-rose-500' : data.remaining < 10 ? 'text-amber-500' : 'text-emerald-500')
: ''
rows.push({
iconBg: 'bg-emerald-500/10', iconColor: 'text-emerald-500', iconSvg: ICON_SHIELD,
label: t('keyUsage.remainingQuota'), value: data.remaining != null ? usd(data.remaining) : '-', valueClass: remainColor,
})
}
return rows
})
interface StatCell {
label: string
value: string
}
const usageStatCells = computed<StatCell[]>(() => {
const usage = resultData.value?.usage
if (!usage) return []
const today = usage.today || {}
const total = usage.total || {}
return [
{ label: t('keyUsage.todayRequests'), value: fmtNum(today.requests) },
{ label: t('keyUsage.todayInputTokens'), value: fmtNum(today.input_tokens) },
{ label: t('keyUsage.todayOutputTokens'), value: fmtNum(today.output_tokens) },
{ label: t('keyUsage.todayTokens'), value: fmtNum(today.total_tokens) },
{ label: t('keyUsage.todayCacheCreation'), value: fmtNum(today.cache_creation_tokens) },
{ label: t('keyUsage.todayCacheRead'), value: fmtNum(today.cache_read_tokens) },
{ label: t('keyUsage.todayCost'), value: usd(today.actual_cost) },
{ label: t('keyUsage.rpmTpm'), value: `${usage.rpm || 0} / ${usage.tpm || 0}` },
{ label: t('keyUsage.totalRequests'), value: fmtNum(total.requests) },
{ label: t('keyUsage.totalInputTokens'), value: fmtNum(total.input_tokens) },
{ label: t('keyUsage.totalOutputTokens'), value: fmtNum(total.output_tokens) },
{ label: t('keyUsage.totalTokensLabel'), value: fmtNum(total.total_tokens) },
{ label: t('keyUsage.totalCacheCreation'), value: fmtNum(total.cache_creation_tokens) },
{ label: t('keyUsage.totalCacheRead'), value: fmtNum(total.cache_read_tokens) },
{ label: t('keyUsage.totalCost'), value: usd(total.actual_cost) },
{ label: t('keyUsage.avgDuration'), value: usage.average_duration_ms ? `${Math.round(usage.average_duration_ms)} ms` : '-' },
]
})
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const modelStats = computed<any[]>(() => resultData.value?.model_stats || [])
// ==================== Utility Functions ====================
function usd(value: number | null | undefined): string {
if (value == null || value < 0) return '-'
return '$' + Number(value).toFixed(2)
}
function fmtNum(val: number | null | undefined): string {
if (val == null) return '-'
return val.toLocaleString()
}
function formatDate(iso: string | null | undefined): string {
if (!iso) return '-'
const d = new Date(iso)
const loc = locale.value === 'zh' ? 'zh-CN' : 'en-US'
return d.toLocaleDateString(loc, { year: 'numeric', month: 'long', day: 'numeric' })
}
// ==================== API Query ====================
async function fetchUsage(key: string) {
const dateParams = getDateParams()
const url = '/v1/usage' + (dateParams ? '?' + dateParams : '')
const res = await fetch(url, {
headers: { 'Authorization': 'Bearer ' + key },
})
if (!res.ok) {
const body = await res.json().catch(() => null)
const msg = body?.error?.message || body?.message || `${t('keyUsage.queryFailed')} (${res.status})`
throw new Error(msg)
}
return await res.json()
}
async function queryKey() {
if (isQuerying.value) return
const key = apiKey.value.trim()
if (!key) {
appStore.showInfo(t('keyUsage.enterApiKey'))
return
}
isQuerying.value = true
showResults.value = true
showLoading.value = true
resultData.value = null
try {
const data = await fetchUsage(key)
resultData.value = data
showLoading.value = false
showDatePicker.value = true
// Trigger ring animations after DOM update
nextTick(() => {
triggerRingAnimation(ringItems.value)
})
appStore.showSuccess(t('keyUsage.querySuccess'))
} catch (err) {
showResults.value = false
showLoading.value = false
appStore.showError((err as Error).message || t('keyUsage.queryFailedRetry'))
} finally {
isQuerying.value = false
}
}
// ==================== Lifecycle ====================
function initTheme() {
const savedTheme = localStorage.getItem('theme')
if (savedTheme === 'dark' || (!savedTheme && window.matchMedia('(prefers-color-scheme: dark)').matches)) {
isDark.value = true
document.documentElement.classList.add('dark')
}
}
onMounted(() => {
initTheme()
if (!appStore.publicSettingsLoaded) {
appStore.fetchPublicSettings()
}
})
</script>
<style scoped>
/* Input focus ring */
.input-ring {
transition: box-shadow 0.2s ease, border-color 0.2s ease;
}
.input-ring:focus {
box-shadow: 0 0 0 3px rgba(20, 184, 166, 0.2);
border-color: #14b8a6;
outline: none;
}
/* Ring animation */
.progress-ring {
transition: stroke-dashoffset 1.2s cubic-bezier(0.4, 0, 0.2, 1);
transform: rotate(-90deg);
transform-origin: 50% 50%;
}
/* Skeleton loading */
@keyframes shimmer-kv {
0% { background-position: -200% 0; }
100% { background-position: 200% 0; }
}
.skeleton {
background: linear-gradient(90deg, #e5e7eb 25%, #f3f4f6 50%, #e5e7eb 75%);
background-size: 200% 100%;
animation: shimmer-kv 1.8s ease-in-out infinite;
border-radius: 8px;
}
:global(.dark) .skeleton {
background: linear-gradient(90deg, #334155 25%, #1e293b 50%, #334155 75%);
background-size: 200% 100%;
}
/* Fade up animation */
@keyframes fade-up-kv {
from { opacity: 0; transform: translateY(16px); }
to { opacity: 1; transform: translateY(0); }
}
.fade-up {
animation: fade-up-kv 0.5s cubic-bezier(0.4, 0, 0.2, 1) forwards;
}
.fade-up-delay-1 { animation-delay: 0.1s; opacity: 0; }
.fade-up-delay-2 { animation-delay: 0.2s; opacity: 0; }
.fade-up-delay-3 { animation-delay: 0.3s; opacity: 0; }
.fade-up-delay-4 { animation-delay: 0.4s; opacity: 0; }
/* Pulse dot */
@keyframes pulse-dot-kv {
0%, 100% { opacity: 1; box-shadow: 0 0 0 0 currentColor; }
50% { opacity: 0.6; box-shadow: 0 0 8px 2px currentColor; }
}
.pulse-dot {
animation: pulse-dot-kv 2s ease-in-out infinite;
}
/* Tabular nums */
.tabular-nums {
font-variant-numeric: tabular-nums;
letter-spacing: -0.02em;
}
</style>

View File

@@ -260,7 +260,8 @@
<ReAuthAccountModal :show="showReAuth" :account="reAuthAcc" @close="closeReAuthModal" @reauthorized="handleAccountUpdated" />
<AccountTestModal :show="showTest" :account="testingAcc" @close="closeTestModal" />
<AccountStatsModal :show="showStats" :account="statsAcc" @close="closeStatsModal" />
<AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @reauth="handleReAuth" @refresh-token="handleRefresh" @reset-status="handleResetStatus" @clear-rate-limit="handleClearRateLimit" />
<ScheduledTestsPanel :show="showSchedulePanel" :account-id="scheduleAcc?.id ?? null" :model-options="scheduleModelOptions" @close="closeSchedulePanel" />
<AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @schedule="handleSchedule" @reauth="handleReAuth" @refresh-token="handleRefresh" @reset-status="handleResetStatus" @clear-rate-limit="handleClearRateLimit" />
<SyncFromCrsModal :show="showSync" @close="showSync = false" @synced="reload" />
<ImportDataModal :show="showImportData" @close="showImportData = false" @imported="handleDataImported" />
<BulkEditAccountModal :show="showBulkEdit" :account-ids="selIds" :selected-platforms="selPlatforms" :selected-types="selTypes" :proxies="proxies" :groups="groups" @close="showBulkEdit = false" @updated="handleBulkUpdated" />
@@ -298,6 +299,8 @@ import ImportDataModal from '@/components/admin/account/ImportDataModal.vue'
import ReAuthAccountModal from '@/components/admin/account/ReAuthAccountModal.vue'
import AccountTestModal from '@/components/admin/account/AccountTestModal.vue'
import AccountStatsModal from '@/components/admin/account/AccountStatsModal.vue'
import ScheduledTestsPanel from '@/components/admin/account/ScheduledTestsPanel.vue'
import type { SelectOption } from '@/components/common/Select.vue'
import AccountStatusIndicator from '@/components/account/AccountStatusIndicator.vue'
import AccountUsageCell from '@/components/account/AccountUsageCell.vue'
import AccountTodayStatsCell from '@/components/account/AccountTodayStatsCell.vue'
@@ -307,7 +310,7 @@ import PlatformTypeBadge from '@/components/common/PlatformTypeBadge.vue'
import Icon from '@/components/icons/Icon.vue'
import ErrorPassthroughRulesModal from '@/components/admin/ErrorPassthroughRulesModal.vue'
import { formatDateTime, formatRelativeTime } from '@/utils/format'
import type { Account, AccountPlatform, AccountType, Proxy, AdminGroup, WindowStats } from '@/types'
import type { Account, AccountPlatform, AccountType, Proxy, AdminGroup, WindowStats, ClaudeModel } from '@/types'
const { t } = useI18n()
const appStore = useAppStore()
@@ -351,6 +354,9 @@ const deletingAcc = ref<Account | null>(null)
const reAuthAcc = ref<Account | null>(null)
const testingAcc = ref<Account | null>(null)
const statsAcc = ref<Account | null>(null)
const showSchedulePanel = ref(false)
const scheduleAcc = ref<Account | null>(null)
const scheduleModelOptions = ref<SelectOption[]>([])
const togglingSchedulable = ref<number | null>(null)
const menu = reactive<{show:boolean, acc:Account|null, pos:{top:number, left:number}|null}>({ show: false, acc: null, pos: null })
const exportingData = ref(false)
@@ -546,18 +552,27 @@ const {
handlePageSizeChange: baseHandlePageSizeChange
} = useTableLoader<Account, any>({
fetchFn: adminAPI.accounts.list,
initialParams: { platform: '', type: '', status: '', group: '', search: '', lite: '1' }
initialParams: { platform: '', type: '', status: '', group: '', search: '' }
})
const resetAutoRefreshCache = () => {
autoRefreshETag.value = null
}
const isFirstLoad = ref(true)
const load = async () => {
hasPendingListSync.value = false
resetAutoRefreshCache()
pendingTodayStatsRefresh.value = false
if (isFirstLoad.value) {
;(params as any).lite = '1'
}
await baseLoad()
if (isFirstLoad.value) {
isFirstLoad.value = false
delete (params as any).lite
}
await refreshTodayStatsBatch()
}
@@ -612,6 +627,7 @@ const isAnyModalOpen = computed(() => {
showReAuth.value ||
showTest.value ||
showStats.value ||
showSchedulePanel.value ||
showErrorPassthrough.value
)
})
@@ -689,7 +705,7 @@ const refreshAccountsIncrementally = async () => {
type?: string
status?: string
search?: string
lite?: string
},
{ etag: autoRefreshETag.value }
)
@@ -1067,6 +1083,18 @@ const closeStatsModal = () => { showStats.value = false; statsAcc.value = null }
const closeReAuthModal = () => { showReAuth.value = false; reAuthAcc.value = null }
const handleTest = (a: Account) => { testingAcc.value = a; showTest.value = true }
const handleViewStats = (a: Account) => { statsAcc.value = a; showStats.value = true }
const handleSchedule = async (a: Account) => {
scheduleAcc.value = a
scheduleModelOptions.value = []
showSchedulePanel.value = true
try {
const models = await adminAPI.accounts.getAvailableModels(a.id)
scheduleModelOptions.value = models.map((m: ClaudeModel) => ({ value: m.id, label: m.display_name || m.id }))
} catch {
scheduleModelOptions.value = []
}
}
const closeSchedulePanel = () => { showSchedulePanel.value = false; scheduleAcc.value = null; scheduleModelOptions.value = [] }
const handleReAuth = (a: Account) => { reAuthAcc.value = a; showReAuth.value = true }
const handleRefresh = async (a: Account) => {
try {

View File

@@ -160,7 +160,7 @@ onUnmounted(() => {
.custom-open-fab {
@apply absolute right-3 top-3 z-10;
@apply shadow-sm backdrop-blur supports-[backdrop-filter]:bg-white/80;
@apply shadow-sm backdrop-blur supports-[backdrop-filter]:bg-white/80 dark:supports-[backdrop-filter]:bg-dark-800/80;
}
.custom-embed-frame {

View File

@@ -143,7 +143,7 @@ onUnmounted(() => {
.purchase-open-fab {
@apply absolute right-3 top-3 z-10;
@apply shadow-sm backdrop-blur supports-[backdrop-filter]:bg-white/80;
@apply shadow-sm backdrop-blur supports-[backdrop-filter]:bg-white/80 dark:supports-[backdrop-filter]:bg-dark-800/80;
}
.purchase-embed-frame {

View File

@@ -113,6 +113,9 @@
<!-- Actions -->
<div class="ml-auto flex items-center gap-3">
<button @click="applyFilters" :disabled="loading" class="btn btn-secondary">
{{ t('common.refresh') }}
</button>
<button @click="resetFilters" class="btn btn-secondary">
{{ t('common.reset') }}
</button>

View File

@@ -115,6 +115,10 @@ export default defineConfig(({ mode }) => {
target: backendUrl,
changeOrigin: true
},
'/v1': {
target: backendUrl,
changeOrigin: true
},
'/setup': {
target: backendUrl,
changeOrigin: true