Compare commits

..

26 Commits

Author SHA1 Message Date
shaw
7fd94ab78b fix: 修复usage页面未显示缓存写入的问题 2025-12-19 16:57:31 +08:00
shaw
078529e51e chore: 更新docker的postgres版本为18 2025-12-19 16:42:03 +08:00
shaw
23a4cf11c8 fix: 设置默认logo作为favicon 2025-12-19 16:41:00 +08:00
shaw
d1f0902ec0 feat(account): 支持账号级别拦截预热请求
- 新增 intercept_warmup_requests 配置项,存储在 credentials 字段
- 启用后,标题生成、Warmup 等预热请求返回 mock 响应,不消耗上游 token
- 前端支持所有账号类型(OAuth、Setup Token、API Key)的开关配置
- 修复 OAuth 凭证刷新时丢失非 token 配置的问题
2025-12-19 16:39:25 +08:00
shaw
ee86dbca9d feat(account): 账号测试支持选择模型
- 新增 GET /api/v1/admin/accounts/:id/models 接口获取账号可用模型
- 账号测试弹窗新增模型选择下拉框
- 测试时支持传入 model_id 参数,不传则默认使用 Sonnet
- API Key 账号支持根据 model_mapping 映射测试模型
- 将模型常量提取到 claude 包统一管理
2025-12-19 16:00:09 +08:00
Wesley Liddick
733d4c2b85 Merge pull request #6 from dexcoder6/main
fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
2025-12-19 02:59:05 -05:00
dexcoder6
406d3f3cab fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
- 修复移动端无法打开菜单栏的问题
  - 在 app.ts 中添加 mobileOpen 状态管理
  - 修复 AppHeader.vue 中移动端菜单按钮调用错误的方法
  - 修复 AppSidebar.vue 使用本地 ref 而非全局状态的问题

- 添加移动端菜单自动关闭功能
  - 点击菜单项后自动关闭侧边栏
  - 添加 150ms 延迟以显示关闭动画

- 修复使用记录页面总消费卡片溢出问题
  - 调整总消费卡片布局,将删除线价格移至说明行
  - 添加 min-w-0 flex-1 防止内容溢出
  - 保持与其他卡片高度一致

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-19 15:55:42 +08:00
shaw
1ed93a5fd0 refactor: 提取 Claude 客户端常量到独立包
- 新增 internal/pkg/claude 包统一管理 Claude Code 相关常量
  - 统一账号测试逻辑,所有账号类型使用相同的 Claude Code 风格请求
  - 网关服务使用常量包替换硬编码的 beta header 字符串
2025-12-19 15:22:52 +08:00
shaw
463ddea36f fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
batchInputHint 中的 @ 符号需要使用 {'@'} 转义
2025-12-19 11:24:22 +08:00
shaw
e769f67699 fix(setup): 支持从配置文件读取 Setup Wizard 监听地址
Setup Wizard 之前硬编码使用 8080 端口,现在支持从 config.yaml 或
环境变量 (SERVER_HOST, SERVER_PORT) 读取监听地址,方便用户在端口
被占用时使用其他地址启动初始化向导。
2025-12-19 11:21:58 +08:00
shaw
52d2ae9708 feat(gateway): 添加 /v1/messages/count_tokens 端点
实现 Claude API 的 token 计数功能,支持 OAuth、SetupToken 和 ApiKey 三种账号类型。

特点:
- 校验订阅/余额(不扣费)
- 不计算用户和账号并发
- 不记录使用量
- 支持模型映射(ApiKey 账号)
- 支持 OAuth 账号的指纹管理和 401 重试
2025-12-19 11:12:41 +08:00
shaw
2e59998c51 fix: 代理表单字段保存时自动去除前后空格
前后端同时处理,防止因意外空格导致代理连接失败
2025-12-19 10:39:30 +08:00
shaw
32e58115cc fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
转义 batchInputPlaceholder 中的 @ 符号,防止 Vue I18n 将其误解析为链接消息语法
2025-12-19 10:32:22 +08:00
shaw
ba27026399 docs: 调整源码编译步骤的顺序 2025-12-19 09:47:17 +08:00
shaw
c15b419c4c feat(backend): 添加 event_logging 接口直接返回200
将原本在nginx处理的遥测日志请求移至后端,
忽略Claude Code客户端发送的日志数据。
2025-12-19 09:39:57 +08:00
shaw
5bd27a5d17 fix(frontend): 优化分组表单中订阅模式的字段显示逻辑
- 订阅模式下隐藏 Exclusive 字段并默认为开启状态
- 编辑分组时禁用计费类型字段,防止修改
- 移除编辑表单中无用的 subscription_type watch
2025-12-19 08:41:30 +08:00
Wesley Liddick
0e7b8aab8c Merge pull request #4 from NepetaLemon/refactor/backend-wire-provider-sets
refactor(backend): 拆分 Wire ProviderSet
2025-12-18 19:27:49 -05:00
Forest
236908c03d refactor(backend): 拆分 Wire ProviderSet 2025-12-19 00:03:29 +08:00
shaw
67d028cf50 fix: 修复用户修改密码接口404问题
将后端路由与前端API调用对齐:
- /user/profile -> /users/me
- PUT /user/password -> POST /users/me/password
2025-12-18 22:59:49 +08:00
shaw
66ba487697 fix: 修复前端github项目地址 2025-12-18 22:47:42 +08:00
Wesley Liddick
8c7875aa4d Merge pull request #3 from NepetaLemon/refactor/backend-wire-bootstrap
refactor(backend): 引入 Wire 重构服务启动与依赖组装
2025-12-18 09:12:15 -05:00
shaw
145171464f fix: 修复前端多个 bug
1. 版本号闪烁问题
   - 将版本信息缓存到 Pinia store,避免每次路由切换都重新请求
   - 添加加载占位符,版本为空时显示骨架屏

2. 管理员登录跳转问题
   - 管理员登录后现在正确跳转到 /admin/dashboard
   - 普通用户仍跳转到 /dashboard

3. Dashboard 页面空白报错
   - 修复 API 返回 null 时访问 .length 导致的 TypeError
   - 为 computed 属性添加可选链操作符保护
   - 为数据赋值添加空数组默认值
2025-12-18 22:11:29 +08:00
Forest
e5aa676853 refactor(backend): 引入 Wire 重构服务启动与依赖组装 2025-12-18 22:07:17 +08:00
shaw
9b4fc42457 feat: 实现后台在线更新功能
- 前端添加更新和重启按钮,支持一键更新 Release 构建
- 修复条件判断优先级问题,确保错误/成功状态正确显示
- 后端使用原子文件替换模式,确保更新过程安全可靠
- 在可执行文件同目录创建临时文件,保证 rename 原子性
- 删除未使用的 copyFile 函数,保持代码整洁
2025-12-18 21:15:10 +08:00
shaw
caae7e4603 feat: 改进安装脚本的交互体验和自动化流程
- 修复 curl | bash 管道模式下无法交互式输入的问题
  - 使用 /dev/tty 检测终端可用性替代 stdin 检测
  - 所有 read 命令从 /dev/tty 读取用户输入
- 安装完成后自动启动服务和启用开机自启
- 使用 ipinfo.io API 获取公网 IP 用于显示访问地址
- 简化安装完成后的输出信息
2025-12-18 20:53:29 +08:00
shaw
a26db8b3e2 fix: 修复前端页面刷新时偶发空白渲染的竞态条件问题
使用 router.isReady() 等待路由器完成初始导航后再挂载应用,
避免 RouterView 在路由未就绪时渲染空的 Comment 节点。
2025-12-18 20:45:56 +08:00
58 changed files with 2475 additions and 932 deletions

View File

@@ -16,6 +16,14 @@ English | [中文](README_CN.md)
---
## Demo
Try Sub2API online: **https://v2.pincc.ai/**
| Email | Password |
|-------|----------|
| admin@sub2api.com | admin123 |
## Overview
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
@@ -208,20 +216,19 @@ Build and run from source code for development or customization.
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. Build backend
cd backend
go build -o sub2api ./cmd/server
# 3. Build frontend
cd ../frontend
# 2. Build frontend
cd frontend
npm install
npm run build
# 4. Copy frontend build to backend (for embedding)
# 3. Copy frontend build to backend (for embedding)
cp -r dist ../backend/internal/web/
# 5. Create configuration file
# 4. Build backend (requires frontend dist to be present)
cd ../backend
go build -o sub2api ./cmd/server
# 5. Create configuration file
cp ../deploy/config.example.yaml ./config.yaml
# 6. Edit configuration

View File

@@ -16,6 +16,14 @@
---
## 在线体验
体验地址:**https://v2.pincc.ai/**
| 邮箱 | 密码 |
|------|------|
| admin@sub2api.com | admin123 |
## 项目概述
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
@@ -208,20 +216,19 @@ docker-compose logs -f
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. 编译
cd backend
go build -o sub2api ./cmd/server
# 3. 编译前端
cd ../frontend
# 2. 编译
cd frontend
npm install
npm run build
# 4. 复制前端构建产物到后端(用于嵌入)
# 3. 复制前端构建产物到后端(用于嵌入)
cp -r dist ../backend/internal/web/
# 5. 创建配置文件
# 4. 编译后端(需要前端 dist 目录存在)
cd ../backend
go build -o sub2api ./cmd/server
# 5. 创建配置文件
cp ../deploy/config.example.yaml ./config.yaml
# 6. 编辑配置

6
backend/Makefile Normal file
View File

@@ -0,0 +1,6 @@
.PHONY: wire
wire:
@echo "生成 Wire 代码..."
@cd cmd/server && go generate
@echo "Wire 代码生成完成"

View File

@@ -1,8 +1,11 @@
package main
//go:generate go run github.com/google/wire/cmd/wire
import (
"context"
_ "embed"
"errors"
"flag"
"log"
"net/http"
@@ -15,18 +18,10 @@ import (
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/setup"
"sub2api/internal/web"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
//go:embed VERSION
@@ -100,8 +95,10 @@ func runSetupServer() {
r.Use(web.ServeEmbeddedFrontend())
}
addr := ":8080"
log.Printf("Setup wizard available at http://localhost%s", addr)
// Get server address from config.yaml or environment variables (SERVER_HOST, SERVER_PORT)
// This allows users to run setup on a different address if needed
addr := config.GetServerAddress()
log.Printf("Setup wizard available at http://%s", addr)
log.Println("Complete the setup wizard to configure Sub2API")
if err := r.Run(addr); err != nil {
@@ -110,78 +107,25 @@ func runSetupServer() {
}
func runMainServer() {
// 加载配置
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 初始化时区(类似 PHP 的 date_default_timezone_set
if err := timezone.Init(cfg.Timezone); err != nil {
log.Fatalf("Failed to initialize timezone: %v", err)
}
// 初始化数据库
db, err := initDB(cfg)
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
// 初始化Redis
rdb := initRedis(cfg)
// 初始化Repository
repos := repository.NewRepositories(db)
// 初始化Service
services := service.NewServices(repos, rdb, cfg)
// 初始化Handler
buildInfo := handler.BuildInfo{
Version: Version,
BuildType: BuildType,
}
handlers := handler.NewHandlers(services, repos, rdb, buildInfo)
// 设置Gin模式
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
}
// 创建路由
r := gin.New()
r.Use(gin.Recovery())
r.Use(middleware.Logger())
r.Use(middleware.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
app, err := initializeApplication(buildInfo)
if err != nil {
log.Fatalf("Failed to initialize application: %v", err)
}
defer app.Cleanup()
// 启动服务器
srv := &http.Server{
Addr: cfg.Server.Address(),
Handler: r,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
// 优雅关闭
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if err := app.Server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start server: %v", err)
}
}()
log.Printf("Server started on %s", cfg.Server.Address())
log.Printf("Server started on %s", app.Server.Addr)
// 等待中断信号
quit := make(chan os.Signal, 1)
@@ -193,289 +137,9 @@ func runMainServer() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
if err := app.Server.Shutdown(ctx); err != nil {
log.Fatalf("Server forced to shutdown: %v", err)
}
log.Println("Server exited")
}
func initDB(cfg *config.Config) (*gorm.DB, error) {
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}
func initRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
// API v1
v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
// OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
}
}

103
backend/cmd/server/wire.go Normal file
View File

@@ -0,0 +1,103 @@
//go:build wireinject
// +build wireinject
package main
import (
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"context"
"log"
"net/http"
"time"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
type Application struct {
Server *http.Server
Cleanup func()
}
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
wire.Build(
// 基础设施层 ProviderSets
config.ProviderSet,
infrastructure.ProviderSet,
// 业务层 ProviderSets
repository.ProviderSet,
service.ProviderSet,
handler.ProviderSet,
// 服务器层 ProviderSet
server.ProviderSet,
// 清理函数提供者
provideCleanup,
// 应用程序结构体
wire.Struct(new(Application), "Server", "Cleanup"),
)
return nil, nil
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Cleanup steps in reverse dependency order
cleanupSteps := []struct {
name string
fn func() error
}{
{"PricingService", func() error {
services.Pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
{"Database", func() error {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}},
}
for _, step := range cleanupSteps {
if err := step.fn(); err != nil {
log.Printf("[Cleanup] %s failed: %v", step.name, err)
// Continue with remaining cleanup steps even if one fails
} else {
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
// Check if context timed out
select {
case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
default:
log.Printf("[Cleanup] All cleanup steps completed")
}
}
}

View File

@@ -0,0 +1,201 @@
// Code generated by Wire. DO NOT EDIT.
//go:generate go run -mod=mod github.com/google/wire/cmd/wire
//go:build !wireinject
// +build !wireinject
package main
import (
"context"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
"log"
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/handler/admin"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"time"
)
import (
_ "embed"
)
// Injectors from wire.go:
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
configConfig, err := config.ProvideConfig()
if err != nil {
return nil, err
}
db, err := infrastructure.ProvideDB(configConfig)
if err != nil {
return nil, err
}
userRepository := repository.NewUserRepository(db)
settingRepository := repository.NewSettingRepository(db)
settingService := service.NewSettingService(settingRepository, configConfig)
client := infrastructure.ProvideRedis(configConfig)
emailService := service.NewEmailService(settingRepository, client)
turnstileService := service.NewTurnstileService(settingService)
emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository, configConfig)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(repositories, billingCacheService)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
adminService := service.NewAdminService(repositories, billingCacheService)
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
oAuthService := service.NewOAuthService(proxyRepository)
rateLimitService := service.NewRateLimitService(repositories, configConfig)
accountUsageService := service.NewAccountUsageService(repositories, oAuthService)
accountTestService := service.NewAccountTestService(repositories, oAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService)
systemHandler := handler.ProvideSystemHandler(client, buildInfo)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
pricingService, err := service.ProvidePricingService(configConfig)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(client)
gatewayService := service.NewGatewayService(repositories, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
concurrencyService := service.NewConcurrencyService(client)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
groupService := service.NewGroupService(groupRepository)
accountService := service.NewAccountService(accountRepository, groupRepository)
proxyService := service.NewProxyService(proxyRepository)
services := &service.Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oAuthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
}
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
httpServer := server.ProvideHTTPServer(configConfig, engine)
v := provideCleanup(db, client, services)
application := &Application{
Server: httpServer,
Cleanup: v,
}
return application, nil
}
// wire.go:
type Application struct {
Server *http.Server
Cleanup func()
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cleanupSteps := []struct {
name string
fn func() error
}{
{"PricingService", func() error {
services.Pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
{"Database", func() error {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}},
}
for _, step := range cleanupSteps {
if err := step.fn(); err != nil {
log.Printf("[Cleanup] %s failed: %v", step.name, err)
} else {
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
select {
case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
default:
log.Printf("[Cleanup] All cleanup steps completed")
}
}
}

View File

@@ -13,6 +13,7 @@ require (
github.com/redis/go-redis/v9 v9.3.0
github.com/spf13/viper v1.18.2
golang.org/x/crypto v0.44.0
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.4
@@ -33,6 +34,8 @@ require (
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/google/wire v0.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -50,6 +53,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
@@ -66,9 +70,11 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/mod v0.29.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)

View File

@@ -48,8 +48,12 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
@@ -154,8 +158,12 @@ golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
@@ -166,6 +174,8 @@ golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=

View File

@@ -203,3 +203,29 @@ func (c *Config) Validate() error {
}
return nil
}
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.
// Priority: config.yaml > environment variables > defaults
func GetServerAddress() string {
v := viper.New()
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("./config")
v.AddConfigPath("/etc/sub2api")
// Support SERVER_HOST and SERVER_PORT environment variables
v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.SetDefault("server.host", "0.0.0.0")
v.SetDefault("server.port", 8080)
// Try to read config file (ignore errors if not found)
_ = v.ReadInConfig()
host := v.GetString("server.host")
port := v.GetInt("server.port")
return fmt.Sprintf("%s:%d", host, port)
}

View File

@@ -0,0 +1,13 @@
package config
import "github.com/google/wire"
// ProviderSet 提供配置层的依赖
var ProviderSet = wire.NewSet(
ProvideConfig,
)
// ProvideConfig 提供应用配置
func ProvideConfig() (*Config, error) {
return Load()
}

View File

@@ -3,6 +3,7 @@ package admin
import (
"strconv"
"sub2api/internal/pkg/claude"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
@@ -186,6 +187,11 @@ func (h *AccountHandler) Delete(c *gin.Context) {
response.Success(c, gin.H{"message": "Account deleted successfully"})
}
// TestAccountRequest represents the request body for testing an account
type TestAccountRequest struct {
ModelID string `json:"model_id"`
}
// Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) {
@@ -195,8 +201,12 @@ func (h *AccountHandler) Test(c *gin.Context) {
return
}
var req TestAccountRequest
// Allow empty body, model_id is optional
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil {
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
// Error already sent via SSE, just log
return
}
@@ -231,16 +241,20 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return
}
// Update account credentials
newCredentials := map[string]interface{}{
"access_token": tokenInfo.AccessToken,
"token_type": tokenInfo.TokenType,
"expires_in": tokenInfo.ExpiresIn,
"expires_at": tokenInfo.ExpiresAt,
"refresh_token": tokenInfo.RefreshToken,
"scope": tokenInfo.Scope,
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
newCredentials := make(map[string]interface{})
for k, v := range account.Credentials {
newCredentials[k] = v
}
// Update token-related fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
newCredentials["refresh_token"] = tokenInfo.RefreshToken
newCredentials["scope"] = tokenInfo.Scope
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
@@ -535,3 +549,58 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
response.Success(c, account)
}
// GetAvailableModels handles getting available models for an account
// GET /api/v1/admin/accounts/:id/models
func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
response.Success(c, claude.DefaultModels)
return
}
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
// No mapping configured, return default models
response.Success(c, claude.DefaultModels)
return
}
// Return mapped models (keys of the mapping are the available model IDs)
var models []claude.Model
for requestedModel := range mapping {
// Try to find display info from default models
var found bool
for _, dm := range claude.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
// If not found in defaults, create a basic entry
if !found {
models = append(models, claude.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
CreatedAt: "",
})
}
}
response.Success(c, models)
}

View File

@@ -2,6 +2,7 @@ package admin
import (
"strconv"
"strings"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
@@ -112,12 +113,12 @@ func (h *ProxyHandler) Create(c *gin.Context) {
}
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: req.Username,
Password: req.Password,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error())
@@ -143,13 +144,13 @@ func (h *ProxyHandler) Update(c *gin.Context) {
}
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: req.Status,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
Status: strings.TrimSpace(req.Status),
})
if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error())
@@ -263,8 +264,14 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
skipped := 0
for _, item := range req.Proxies {
// Trim all string fields
host := strings.TrimSpace(item.Host)
protocol := strings.TrimSpace(item.Protocol)
username := strings.TrimSpace(item.Username)
password := strings.TrimSpace(item.Password)
// Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
return
@@ -278,11 +285,11 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
// Create proxy with default name
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: "default",
Protocol: item.Protocol,
Host: item.Host,
Protocol: protocol,
Host: host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
Username: username,
Password: password,
})
if err != nil {
// If creation fails due to duplicate, count as skipped

View File

@@ -7,10 +7,12 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/pkg/claude"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -126,6 +128,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
if err != nil {
@@ -285,29 +297,8 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id
// Models handles listing available models
// GET /v1/models
func (h *GatewayHandler) Models(c *gin.Context) {
models := []gin.H{
{
"id": "claude-opus-4-5-20251101",
"type": "model",
"display_name": "Claude Opus 4.5",
"created_at": "2025-11-01T00:00:00Z",
},
{
"id": "claude-sonnet-4-5-20250929",
"type": "model",
"display_name": "Claude Sonnet 4.5",
"created_at": "2025-09-29T00:00:00Z",
},
{
"id": "claude-haiku-4-5-20251001",
"type": "model",
"display_name": "Claude Haiku 4.5",
"created_at": "2025-10-01T00:00:00Z",
},
}
c.JSON(http.StatusOK, gin.H{
"data": models,
"data": claude.DefaultModels,
"object": "list",
})
}
@@ -443,3 +434,155 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
},
})
}
// CountTokens handles token counting endpoint
// POST /v1/messages/count_tokens
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// 解析请求获取模型名
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 获取订阅信息可能为nil
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return
}
// 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理
return
}
}
// isWarmupRequest 检测是否为预热请求标题生成、Warmup等
func isWarmupRequest(body []byte) bool {
// 快速检查如果body不包含关键字直接返回false
bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
return false
}
// 解析完整请求
var req struct {
Messages []struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
} `json:"messages"`
System []struct {
Text string `json:"text"`
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
return false
}
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return true
}
}
}
}
// 检查 system 中的标题提取模式
for _, system := range req.System {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return true
}
}
return false
}
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
events := []string{
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
time.Sleep(20 * time.Millisecond)
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
"output_tokens": 2,
},
})
}

View File

@@ -2,10 +2,6 @@ package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// AdminHandlers contains all admin-related HTTP handlers
@@ -41,30 +37,3 @@ type BuildInfo struct {
Version string
BuildType string // "source" for manual builds, "release" for CI builds
}
// NewHandlers creates a new Handlers instance with all handlers initialized
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
return &Handlers{
Auth: NewAuthHandler(services.Auth),
User: NewUserHandler(services.User),
APIKey: NewAPIKeyHandler(services.ApiKey),
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
Redeem: NewRedeemHandler(services.Redeem),
Subscription: NewSubscriptionHandler(services.Subscription),
Admin: &AdminHandlers{
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
User: admin.NewUserHandler(services.Admin),
Group: admin.NewGroupHandler(services.Admin),
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
Proxy: admin.NewProxyHandler(services.Admin),
Redeem: admin.NewRedeemHandler(services.Admin),
Setting: admin.NewSettingHandler(services.Setting, services.Email),
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
Subscription: admin.NewSubscriptionHandler(services.Subscription),
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
},
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
}
}

View File

@@ -0,0 +1,103 @@
package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/service"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
// ProvideAdminHandlers creates the AdminHandlers struct
func ProvideAdminHandlers(
dashboardHandler *admin.DashboardHandler,
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
oauthHandler *admin.OAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
settingHandler *admin.SettingHandler,
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
User: userHandler,
Group: groupHandler,
Account: accountHandler,
OAuth: oauthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Setting: settingHandler,
System: systemHandler,
Subscription: subscriptionHandler,
Usage: usageHandler,
}
}
// ProvideSystemHandler creates admin.SystemHandler with BuildInfo parameters
func ProvideSystemHandler(rdb *redis.Client, buildInfo BuildInfo) *admin.SystemHandler {
return admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType)
}
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
return NewSettingHandler(settingService, buildInfo.Version)
}
// ProvideHandlers creates the Handlers struct
func ProvideHandlers(
authHandler *AuthHandler,
userHandler *UserHandler,
apiKeyHandler *APIKeyHandler,
usageHandler *UsageHandler,
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
settingHandler *SettingHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
User: userHandler,
APIKey: apiKeyHandler,
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
Setting: settingHandler,
}
}
// ProviderSet is the Wire provider set for all handlers
var ProviderSet = wire.NewSet(
// Top-level handlers
NewAuthHandler,
NewUserHandler,
NewAPIKeyHandler,
NewUsageHandler,
NewRedeemHandler,
NewSubscriptionHandler,
NewGatewayHandler,
ProvideSettingHandler,
// Admin handlers
admin.NewDashboardHandler,
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewSettingHandler,
ProvideSystemHandler,
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
ProvideHandlers,
)

View File

@@ -0,0 +1,38 @@
package infrastructure
import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// InitDB 初始化数据库连接
func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 初始化时区(在数据库连接之前,确保时区设置正确)
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, err
}
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}

View File

@@ -0,0 +1,16 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}

View File

@@ -0,0 +1,25 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// ProviderSet 提供基础设施层的依赖
var ProviderSet = wire.NewSet(
ProvideDB,
ProvideRedis,
)
// ProvideDB 提供数据库连接
func ProvideDB(cfg *config.Config) (*gorm.DB, error) {
return InitDB(cfg)
}
// ProvideRedis 提供 Redis 客户端
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

View File

@@ -263,3 +263,17 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
}
return false
}
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
// 启用后标题生成、Warmup等预热请求将返回mock响应不消耗上游token
func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil {
return false
}
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}

View File

@@ -0,0 +1,74 @@
package claude
// Claude Code 客户端相关常量
// Beta header 常量
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0",
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
// Model 表示一个 Claude 模型
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels Claude Code 客户端支持的默认模型列表
var DefaultModels = []Model{
{
ID: "claude-opus-4-5-20251101",
Type: "model",
DisplayName: "Claude Opus 4.5",
CreatedAt: "2025-11-01T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",
DisplayName: "Claude Sonnet 4.5",
CreatedAt: "2025-09-29T00:00:00Z",
},
{
ID: "claude-haiku-4-5-20251001",
Type: "model",
DisplayName: "Claude Haiku 4.5",
CreatedAt: "2025-10-01T00:00:00Z",
},
}
// DefaultModelIDs 返回默认模型的 ID 列表
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929"

View File

@@ -1,9 +1,5 @@
package repository
import (
"gorm.io/gorm"
)
// Repositories 所有仓库的集合
type Repositories struct {
User *UserRepository
@@ -17,21 +13,6 @@ type Repositories struct {
UserSubscription *UserSubscriptionRepository
}
// NewRepositories 创建所有仓库
func NewRepositories(db *gorm.DB) *Repositories {
return &Repositories{
User: NewUserRepository(db),
ApiKey: NewApiKeyRepository(db),
Group: NewGroupRepository(db),
Account: NewAccountRepository(db),
Proxy: NewProxyRepository(db),
RedeemCode: NewRedeemCodeRepository(db),
UsageLog: NewUsageLogRepository(db),
Setting: NewSettingRepository(db),
UserSubscription: NewUserSubscriptionRepository(db),
}
}
// PaginationParams 分页参数
type PaginationParams struct {
Page int

View File

@@ -0,0 +1,19 @@
package repository
import (
"github.com/google/wire"
)
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewApiKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewUsageLogRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
wire.Struct(new(Repositories), "*"),
)

View File

@@ -0,0 +1,45 @@
package server
import (
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/repository"
"sub2api/internal/service"
"time"
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// ProviderSet 提供服务器层的依赖
var ProviderSet = wire.NewSet(
ProvideRouter,
ProvideHTTPServer,
)
// ProvideRouter 提供路由器
func ProvideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
}
r := gin.New()
r.Use(gin.Recovery())
return SetupRouter(r, cfg, handlers, services, repos)
}
// ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
return &http.Server{
Addr: cfg.Server.Address(),
Handler: router,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
}

View File

@@ -0,0 +1,289 @@
package server
import (
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/web"
"github.com/gin-gonic/gin"
)
// SetupRouter 配置路由器中间件和路由
func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
// 应用中间件
r.Use(middleware.Logger())
r.Use(middleware.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
}
return r
}
// registerRoutes 注册所有 HTTP 路由
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Claude Code 遥测日志忽略直接返回200
r.POST("/api/event_logging/batch", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
// API v1
v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
// OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
}
}

View File

@@ -15,6 +15,7 @@ import (
"strings"
"time"
"sub2api/internal/pkg/claude"
"sub2api/internal/repository"
"github.com/gin-gonic/gin"
@@ -23,7 +24,6 @@ import (
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testModel = "claude-sonnet-4-5-20250929"
)
// TestEvent represents a SSE event for account testing
@@ -62,10 +62,10 @@ func generateSessionString() string {
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
}
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts
func createTestPayload() map[string]interface{} {
// createTestPayload creates a Claude Code style test request payload
func createTestPayload(modelID string) map[string]interface{} {
return map[string]interface{}{
"model": testModel,
"model": modelID,
"messages": []map[string]interface{}{
{
"role": "user",
@@ -92,29 +92,16 @@ func createTestPayload() map[string]interface{} {
"metadata": map[string]string{
"user_id": generateSessionString(),
},
"max_tokens": 1024,
"max_tokens": 1024,
"temperature": 1,
"stream": true,
}
}
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
func createApiKeyTestPayload(model string) map[string]interface{} {
return map[string]interface{}{
"model": model,
"messages": []map[string]interface{}{
{
"role": "user",
"content": "hi",
},
},
"max_tokens": 1024,
"stream": true,
"stream": true,
}
}
// TestAccountConnection tests an account's connection by sending a test request
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64) error {
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
ctx := c.Request.Context()
// Get account
@@ -123,14 +110,30 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.sendErrorAndEnd(c, "Account not found")
}
// Determine authentication method based on account type
// Determine the model to use
testModelID := modelID
if testModelID == "" {
testModelID = claude.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if mapping != nil && len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
}
// Determine authentication method and API URL
var authToken string
var authType string // "bearer" for OAuth, "apikey" for API Key
var useBearer bool
var apiURL string
if account.IsOAuth() {
// OAuth or Setup Token account
authType = "bearer"
// OAuth or Setup Token - use Bearer token
useBearer = true
apiURL = testClaudeAPIURL
authToken = account.GetCredential("access_token")
if authToken == "" {
@@ -141,7 +144,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
needRefresh := false
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+300 > expiresAt { // 5 minute buffer
if err == nil && time.Now().Unix()+300 > expiresAt {
needRefresh = true
}
}
@@ -154,19 +157,17 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
authToken = tokenInfo.AccessToken
}
} else if account.Type == "apikey" {
// API Key account
authType = "apikey"
// API Key - use x-api-key header
useBearer = false
authToken = account.GetCredential("api_key")
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
// Get base URL (use default if not set)
apiURL = account.GetBaseURL()
if apiURL == "" {
apiURL = "https://api.anthropic.com"
}
// Append /v1/messages endpoint
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
@@ -179,37 +180,32 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create test request payload
var payload map[string]interface{}
var actualModel string
if authType == "apikey" {
// Use simpler payload for API Key (without Claude Code specific fields)
// Apply model mapping if configured
actualModel = account.GetMappedModel(testModel)
payload = createApiKeyTestPayload(actualModel)
} else {
actualModel = testModel
payload = createTestPayload()
}
// Create Claude Code style payload (same for all account types)
payload := createTestPayload(testModelID)
payloadBytes, _ := json.Marshal(payload)
// Send test_start event with model info
s.sendEvent(c, TestEvent{Type: "test_start", Model: actualModel})
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
// Set headers based on auth type
// Set common headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
if authType == "bearer" {
// Apply Claude Code client headers
for key, value := range claude.DefaultHeaders {
req.Header.Set(key, value)
}
// Set authentication header
if useBearer {
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19")
} else {
// API Key uses x-api-key header
req.Header.Set("x-api-key", authToken)
}
@@ -252,7 +248,6 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
// Stream ended, send complete event
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
@@ -310,5 +305,5 @@ func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
log.Printf("Account test error: %s", errorMsg)
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
return fmt.Errorf(errorMsg)
return fmt.Errorf("%s", errorMsg)
}

View File

@@ -191,24 +191,17 @@ type adminServiceImpl struct {
}
// NewAdminService creates a new AdminService
func NewAdminService(repos *repository.Repositories) AdminService {
func NewAdminService(repos *repository.Repositories, billingCacheService *BillingCacheService) AdminService {
return &adminServiceImpl{
userRepo: repos.User,
groupRepo: repos.Group,
accountRepo: repos.Account,
proxyRepo: repos.Proxy,
apiKeyRepo: repos.ApiKey,
redeemCodeRepo: repos.RedeemCode,
usageLogRepo: repos.UsageLog,
userSubRepo: repos.UserSubscription,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
// 注意AdminService是接口需要类型断言
func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) {
if impl, ok := adminService.(*adminServiceImpl); ok {
impl.billingCacheService = billingCacheService
userRepo: repos.User,
groupRepo: repos.Group,
accountRepo: repos.Account,
proxyRepo: repos.Proxy,
apiKeyRepo: repos.ApiKey,
redeemCodeRepo: repos.RedeemCode,
usageLogRepo: repos.UsageLog,
userSubRepo: repos.UserSubscription,
billingCacheService: billingCacheService,
}
}

View File

@@ -16,13 +16,13 @@ import (
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
)
// JWTClaims JWT载荷数据
@@ -44,33 +44,24 @@ type AuthService struct {
}
// NewAuthService 创建认证服务实例
func NewAuthService(userRepo *repository.UserRepository, cfg *config.Config) *AuthService {
func NewAuthService(
userRepo *repository.UserRepository,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
) *AuthService {
return &AuthService{
userRepo: userRepo,
cfg: cfg,
userRepo: userRepo,
cfg: cfg,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
}
}
// SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证)
func (s *AuthService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
// SetEmailService 设置邮件服务(用于邮件验证)
func (s *AuthService) SetEmailService(emailService *EmailService) {
s.emailService = emailService
}
// SetTurnstileService 设置Turnstile服务用于验证码校验
func (s *AuthService) SetTurnstileService(turnstileService *TurnstileService) {
s.turnstileService = turnstileService
}
// SetEmailQueueService 设置邮件队列服务(用于异步发送邮件)
func (s *AuthService) SetEmailQueueService(emailQueueService *EmailQueueService) {
s.emailQueueService = emailQueueService
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
return s.RegisterWithVerification(ctx, email, password, "")

View File

@@ -20,6 +20,7 @@ import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/claude"
"sub2api/internal/repository"
"github.com/gin-gonic/gin"
@@ -27,10 +28,11 @@ import (
)
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
stickySessionPrefix = "sticky_session:"
stickySessionTTL = time.Hour // 粘性会话TTL
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionPrefix = "sticky_session:"
stickySessionTTL = time.Hour // 粘性会话TTL
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
)
// allowedHeaders 白名单headers参考CRS项目
@@ -601,13 +603,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
const oauthBeta = "oauth-2025-04-20"
const claudeCodeBeta = "claude-code-20250219"
// 如果客户端传了anthropic-beta
if clientBetaHeader != "" {
// 已包含oauth beta则直接返回
if strings.Contains(clientBetaHeader, oauthBeta) {
if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
return clientBetaHeader
}
@@ -620,7 +619,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
// 在claude-code-20250219后面插入oauth beta
claudeCodeIdx := -1
for i, p := range parts {
if p == claudeCodeBeta {
if p == claude.BetaClaudeCode {
claudeCodeIdx = i
break
}
@@ -630,13 +629,13 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
// 在claude-code后面插入
newParts := make([]string, 0, len(parts)+1)
newParts = append(newParts, parts[:claudeCodeIdx+1]...)
newParts = append(newParts, oauthBeta)
newParts = append(newParts, claude.BetaOAuth)
newParts = append(newParts, parts[claudeCodeIdx+1:]...)
return strings.Join(newParts, ",")
}
// 没有claude-code放在第一位
return oauthBeta + "," + clientBetaHeader
return claude.BetaOAuth + "," + clientBetaHeader
}
// 客户端没传,根据模型生成
@@ -650,10 +649,10 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
// haiku模型不需要claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") {
return "oauth-2025-04-20,interleaved-thinking-2025-05-14"
return claude.HaikuBetaHeader
}
return "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
return claude.DefaultBetaHeader
}
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
@@ -1044,3 +1043,205 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error {
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == model.AccountTypeApiKey {
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
body = s.replaceModelInBody(body, mappedModel)
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
}
}
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
return err
}
// 构建上游请求
upstreamResult, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
}
// 选择 HTTP client
httpClient := s.httpClient
if upstreamResult.Client != nil {
httpClient = upstreamResult.Client
}
// 发送请求
resp, err := httpClient.Do(upstreamResult.Request)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
}
defer resp.Body.Close()
// 处理 401 错误:刷新 token 重试(仅 OAuth
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
resp.Body.Close()
token, tokenType, err = s.forceRefreshToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
return fmt.Errorf("token refresh failed: %w", err)
}
upstreamResult, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return err
}
httpClient = s.httpClient
if upstreamResult.Client != nil {
httpClient = upstreamResult.Client
}
resp, err = httpClient.Do(upstreamResult.Request)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
return fmt.Errorf("retry request failed: %w", err)
}
defer resp.Body.Close()
}
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
// 处理错误响应
if resp.StatusCode >= 400 {
// 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 返回简化的错误响应
errMsg := "Upstream request failed"
switch resp.StatusCode {
case 429:
errMsg = "Rate limit exceeded"
case 529:
errMsg = "Service overloaded"
}
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
return fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// 透传成功响应
c.Data(resp.StatusCode, "application/json", respBody)
return nil
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == model.AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
// OAuth 账号:应用统一指纹和重写 userID
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
}
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("Authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
// 白名单透传 headers
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
}
}
// 确保必要的 headers 存在
if req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
}
// 配置代理
var customClient *http.Client
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
customClient = &http.Client{Transport: transport}
}
}
}
return &buildUpstreamRequestResult{
Request: req,
Client: customClient,
}, nil
}
// countTokensError 返回 count_tokens 错误响应
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -57,20 +57,22 @@ type RedeemService struct {
}
// NewRedeemService 创建兑换码服务实例
func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService {
func NewRedeemService(
redeemRepo *repository.RedeemCodeRepository,
userRepo *repository.UserRepository,
subscriptionService *SubscriptionService,
rdb *redis.Client,
billingCacheService *BillingCacheService,
) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
rdb: rdb,
billingCacheService: billingCacheService,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
}
// GenerateRandomCode 生成随机兑换码
func (s *RedeemService) GenerateRandomCode() (string, error) {
// 生成16字节随机数据

View File

@@ -1,12 +1,5 @@
package service
import (
"sub2api/internal/config"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// Services 服务集合容器
type Services struct {
Auth *AuthService
@@ -34,106 +27,3 @@ type Services struct {
Concurrency *ConcurrencyService
Identity *IdentityService
}
// NewServices 创建所有服务实例
func NewServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *Services {
// 初始化价格服务
pricingService := NewPricingService(cfg)
if err := pricingService.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error())
}
// 初始化计费服务(依赖价格服务)
billingService := NewBillingService(cfg, pricingService)
// 初始化其他服务
authService := NewAuthService(repos.User, cfg)
userService := NewUserService(repos.User, cfg)
apiKeyService := NewApiKeyService(repos.ApiKey, repos.User, repos.Group, repos.UserSubscription, rdb, cfg)
groupService := NewGroupService(repos.Group)
accountService := NewAccountService(repos.Account, repos.Group)
proxyService := NewProxyService(repos.Proxy)
usageService := NewUsageService(repos.UsageLog, repos.User)
// 初始化订阅服务 (RedeemService 依赖)
subscriptionService := NewSubscriptionService(repos)
// 初始化兑换服务 (依赖订阅服务)
redeemService := NewRedeemService(repos.RedeemCode, repos.User, subscriptionService, rdb)
// 初始化Admin服务
adminService := NewAdminService(repos)
// 初始化OAuth服务GatewayService依赖
oauthService := NewOAuthService(repos.Proxy)
// 初始化限流服务
rateLimitService := NewRateLimitService(repos, cfg)
// 初始化计费缓存服务
billingCacheService := NewBillingCacheService(rdb, repos.User, repos.UserSubscription)
// 初始化账号使用量服务
accountUsageService := NewAccountUsageService(repos, oauthService)
// 初始化账号测试服务
accountTestService := NewAccountTestService(repos, oauthService)
// 初始化身份指纹服务
identityService := NewIdentityService(rdb)
// 初始化Gateway服务
gatewayService := NewGatewayService(repos, rdb, cfg, oauthService, billingService, rateLimitService, billingCacheService, identityService)
// 初始化设置服务
settingService := NewSettingService(repos.Setting, cfg)
emailService := NewEmailService(repos.Setting, rdb)
// 初始化邮件队列服务
emailQueueService := NewEmailQueueService(emailService, 3)
// 初始化Turnstile服务
turnstileService := NewTurnstileService(settingService)
// 设置Auth服务的依赖用于注册开关和邮件验证
authService.SetSettingService(settingService)
authService.SetEmailService(emailService)
authService.SetTurnstileService(turnstileService)
authService.SetEmailQueueService(emailQueueService)
// 初始化并发控制服务
concurrencyService := NewConcurrencyService(rdb)
// 注入计费缓存服务到需要失效缓存的服务
redeemService.SetBillingCacheService(billingCacheService)
subscriptionService.SetBillingCacheService(billingCacheService)
SetAdminServiceBillingCache(adminService, billingCacheService)
return &Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oauthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
}
}

View File

@@ -28,13 +28,11 @@ type SubscriptionService struct {
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(repos *repository.Repositories) *SubscriptionService {
return &SubscriptionService{repos: repos}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *SubscriptionService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
func NewSubscriptionService(repos *repository.Repositories, billingCacheService *BillingCacheService) *SubscriptionService {
return &SubscriptionService{
repos: repos,
billingCacheService: billingCacheService,
}
}
// AssignSubscriptionInput 分配订阅输入
@@ -88,6 +86,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
// 如果用户已有同分组的订阅:
// - 未过期:从当前过期时间累加天数
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
//
// 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型
@@ -191,15 +190,15 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass
now := time.Now()
sub := &model.UserSubscription{
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive,
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive,
AssignedAt: now,
Notes: input.Notes,
CreatedAt: now,
UpdatedAt: now,
Notes: input.Notes,
CreatedAt: now,
UpdatedAt: now,
}
// 只有当 AssignedBy > 0 时才设置0 表示系统分配,如兑换码)
if input.AssignedBy > 0 {
@@ -225,17 +224,17 @@ type BulkAssignSubscriptionInput struct {
// BulkAssignResult 批量分配结果
type BulkAssignResult struct {
SuccessCount int
FailedCount int
SuccessCount int
FailedCount int
Subscriptions []model.UserSubscription
Errors []string
Errors []string
}
// BulkAssignSubscription 批量分配订阅
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
result := &BulkAssignResult{
Subscriptions: make([]model.UserSubscription, 0),
Errors: make([]string, 0),
Errors: make([]string, 0),
}
for _, userID := range input.UserIDs {
@@ -417,10 +416,10 @@ func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID in
// SubscriptionProgress 订阅进度
type SubscriptionProgress struct {
ID int64 `json:"id"`
GroupName string `json:"group_name"`
ExpiresAt time.Time `json:"expires_at"`
ExpiresInDays int `json:"expires_in_days"`
ID int64 `json:"id"`
GroupName string `json:"group_name"`
ExpiresAt time.Time `json:"expires_at"`
ExpiresInDays int `json:"expires_in_days"`
Daily *UsageWindowProgress `json:"daily,omitempty"`
Weekly *UsageWindowProgress `json:"weekly,omitempty"`
Monthly *UsageWindowProgress `json:"monthly,omitempty"`
@@ -428,13 +427,13 @@ type SubscriptionProgress struct {
// UsageWindowProgress 使用窗口进度
type UsageWindowProgress struct {
LimitUSD float64 `json:"limit_usd"`
UsedUSD float64 `json:"used_usd"`
RemainingUSD float64 `json:"remaining_usd"`
Percentage float64 `json:"percentage"`
WindowStart time.Time `json:"window_start"`
ResetsAt time.Time `json:"resets_at"`
ResetsInSeconds int64 `json:"resets_in_seconds"`
LimitUSD float64 `json:"limit_usd"`
UsedUSD float64 `json:"used_usd"`
RemainingUSD float64 `json:"remaining_usd"`
Percentage float64 `json:"percentage"`
WindowStart time.Time `json:"window_start"`
ResetsAt time.Time `json:"resets_at"`
ResetsInSeconds int64 `json:"resets_in_seconds"`
}
// GetSubscriptionProgress 获取订阅使用进度
@@ -464,12 +463,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
limit := *group.DailyLimitUSD
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
progress.Daily = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD,
RemainingUSD: limit - sub.DailyUsageUSD,
Percentage: (sub.DailyUsageUSD / limit) * 100,
WindowStart: *sub.DailyWindowStart,
ResetsAt: resetsAt,
LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD,
RemainingUSD: limit - sub.DailyUsageUSD,
Percentage: (sub.DailyUsageUSD / limit) * 100,
WindowStart: *sub.DailyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Daily.RemainingUSD < 0 {
@@ -488,12 +487,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
limit := *group.WeeklyLimitUSD
resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour)
progress.Weekly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.WeeklyUsageUSD,
RemainingUSD: limit - sub.WeeklyUsageUSD,
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
WindowStart: *sub.WeeklyWindowStart,
ResetsAt: resetsAt,
LimitUSD: limit,
UsedUSD: sub.WeeklyUsageUSD,
RemainingUSD: limit - sub.WeeklyUsageUSD,
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
WindowStart: *sub.WeeklyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Weekly.RemainingUSD < 0 {
@@ -512,12 +511,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
limit := *group.MonthlyLimitUSD
resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour)
progress.Monthly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.MonthlyUsageUSD,
RemainingUSD: limit - sub.MonthlyUsageUSD,
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
WindowStart: *sub.MonthlyWindowStart,
ResetsAt: resetsAt,
LimitUSD: limit,
UsedUSD: sub.MonthlyUsageUSD,
RemainingUSD: limit - sub.MonthlyUsageUSD,
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
WindowStart: *sub.MonthlyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Monthly.RemainingUSD < 0 {

View File

@@ -125,6 +125,7 @@ func (s *UpdateService) CheckUpdate(ctx context.Context, force bool) (*UpdateInf
}
// PerformUpdate downloads and applies the update
// Uses atomic file replacement pattern for safe in-place updates
func (s *UpdateService) PerformUpdate(ctx context.Context) error {
info, err := s.CheckUpdate(ctx, true)
if err != nil {
@@ -173,8 +174,11 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
return fmt.Errorf("failed to resolve symlinks: %w", err)
}
// Create temp directory for extraction
tempDir, err := os.MkdirTemp("", "sub2api-update-*")
exeDir := filepath.Dir(exePath)
// Create temp directory in the SAME directory as executable
// This ensures os.Rename is atomic (same filesystem)
tempDir, err := os.MkdirTemp(exeDir, ".sub2api-update-*")
if err != nil {
return fmt.Errorf("failed to create temp dir: %w", err)
}
@@ -199,23 +203,36 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
return fmt.Errorf("extraction failed: %w", err)
}
// Backup current binary
backupFile := exePath + ".backup"
if err := os.Rename(exePath, backupFile); err != nil {
return fmt.Errorf("backup failed: %w", err)
}
// Replace with new binary
if err := copyFile(newBinaryPath, exePath); err != nil {
os.Rename(backupFile, exePath)
return fmt.Errorf("replace failed: %w", err)
}
// Make executable
if err := os.Chmod(exePath, 0755); err != nil {
// Set executable permission before replacement
if err := os.Chmod(newBinaryPath, 0755); err != nil {
return fmt.Errorf("chmod failed: %w", err)
}
// Atomic replacement using rename pattern:
// 1. Rename current -> backup (atomic on Unix)
// 2. Rename new -> current (atomic on Unix, same filesystem)
// If step 2 fails, restore backup
backupPath := exePath + ".backup"
// Remove old backup if exists
os.Remove(backupPath)
// Step 1: Move current binary to backup
if err := os.Rename(exePath, backupPath); err != nil {
return fmt.Errorf("backup failed: %w", err)
}
// Step 2: Move new binary to target location (atomic, same filesystem)
if err := os.Rename(newBinaryPath, exePath); err != nil {
// Restore backup on failure
if restoreErr := os.Rename(backupPath, exePath); restoreErr != nil {
return fmt.Errorf("replace failed and restore failed: %w (restore error: %v)", err, restoreErr)
}
return fmt.Errorf("replace failed (restored backup): %w", err)
}
// Success - backup file is kept for rollback capability
// It will be cleaned up on next successful update
return nil
}
@@ -515,23 +532,6 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
return err
}
func copyFile(src, dst string) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, in)
return err
}
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
if err != nil {

View File

@@ -0,0 +1,54 @@
package service
import (
"sub2api/internal/config"
"github.com/google/wire"
)
// ProvidePricingService creates and initializes PricingService
func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
svc := NewPricingService(cfg)
if err := svc.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error())
}
return svc, nil
}
// ProvideEmailQueueService creates EmailQueueService with default worker count
func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
return NewEmailQueueService(emailService, 3)
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewApiKeyService,
NewGroupService,
NewAccountService,
NewProxyService,
NewRedeemService,
NewUsageService,
ProvidePricingService,
NewBillingService,
NewBillingCacheService,
NewAdminService,
NewGatewayService,
NewOAuthService,
NewRateLimitService,
NewAccountUsageService,
NewAccountTestService,
NewSettingService,
NewEmailService,
ProvideEmailQueueService,
NewTurnstileService,
NewSubscriptionService,
NewConcurrencyService,
NewIdentityService,
// Provide the Services container struct
wire.Struct(new(Services), "*"),
)

8
backend/tools.go Normal file
View File

@@ -0,0 +1,8 @@
//go:build tools
// +build tools
package tools
import (
_ "github.com/google/wire/cmd/wire"
)

View File

@@ -96,7 +96,7 @@ services:
# PostgreSQL Database
# ===========================================================================
postgres:
image: postgres:15-alpine
image: postgres:18-alpine
container_name: sub2api-postgres
restart: unless-stopped
volumes:

View File

@@ -128,6 +128,15 @@ declare -A MSG_ZH=(
["server_port_hint"]="建议使用 1024-65535 之间的端口"
["server_config_summary"]="服务器配置"
["invalid_port"]="无效端口号,请输入 1-65535 之间的数字"
# Service management
["starting_service"]="正在启动服务..."
["service_started"]="服务已启动"
["service_start_failed"]="服务启动失败,请检查日志"
["enabling_autostart"]="正在设置开机自启..."
["autostart_enabled"]="开机自启已启用"
["getting_public_ip"]="正在获取公网 IP..."
["public_ip_failed"]="无法获取公网 IP使用本地 IP"
)
# English strings
@@ -225,6 +234,15 @@ declare -A MSG_EN=(
["server_port_hint"]="Recommended range: 1024-65535"
["server_config_summary"]="Server configuration"
["invalid_port"]="Invalid port number, please enter a number between 1-65535"
# Service management
["starting_service"]="Starting service..."
["service_started"]="Service started"
["service_start_failed"]="Service failed to start, please check logs"
["enabling_autostart"]="Enabling auto-start on boot..."
["autostart_enabled"]="Auto-start enabled"
["getting_public_ip"]="Getting public IP..."
["public_ip_failed"]="Failed to get public IP, using local IP"
)
# Get message based on current language
@@ -254,9 +272,11 @@ print_error() {
echo -e "${RED}[$(msg 'error')]${NC} $1"
}
# Check if running interactively (stdin is a terminal)
# Check if running interactively (can access terminal)
# When piped (curl | bash), stdin is not a terminal, but /dev/tty may still be available
is_interactive() {
[ -t 0 ]
# Check if /dev/tty is available (works even when piped)
[ -e /dev/tty ] && [ -r /dev/tty ] && [ -w /dev/tty ]
}
# Select language
@@ -276,7 +296,7 @@ select_language() {
echo " 2) $(msg 'lang_en')"
echo ""
read -p "$(msg 'enter_choice'): " lang_input
read -p "$(msg 'enter_choice'): " lang_input < /dev/tty
case "$lang_input" in
2|en|EN|english|English)
@@ -317,7 +337,7 @@ configure_server() {
# Server host
echo -e "${YELLOW}$(msg 'server_host_hint')${NC}"
read -p "$(msg 'server_host_prompt') [${SERVER_HOST}]: " input_host
read -p "$(msg 'server_host_prompt') [${SERVER_HOST}]: " input_host < /dev/tty
if [ -n "$input_host" ]; then
SERVER_HOST="$input_host"
fi
@@ -327,7 +347,7 @@ configure_server() {
# Server port
echo -e "${YELLOW}$(msg 'server_port_hint')${NC}"
while true; do
read -p "$(msg 'server_port_prompt') [${SERVER_PORT}]: " input_port
read -p "$(msg 'server_port_prompt') [${SERVER_PORT}]: " input_port < /dev/tty
if [ -z "$input_port" ]; then
# Use default
break
@@ -566,13 +586,61 @@ prepare_for_setup() {
print_success "$(msg 'ready_for_setup')"
}
# Get public IP address
get_public_ip() {
print_info "$(msg 'getting_public_ip')"
# Try to get public IP from ipinfo.io
local response
response=$(curl -s --connect-timeout 5 --max-time 10 "https://ipinfo.io/json" 2>/dev/null)
if [ -n "$response" ]; then
# Extract IP from JSON response using grep and sed (no jq dependency)
PUBLIC_IP=$(echo "$response" | grep -o '"ip": *"[^"]*"' | sed 's/"ip": *"\([^"]*\)"/\1/')
if [ -n "$PUBLIC_IP" ]; then
print_success "Public IP: $PUBLIC_IP"
return 0
fi
fi
# Fallback to local IP
print_warning "$(msg 'public_ip_failed')"
PUBLIC_IP=$(hostname -I 2>/dev/null | awk '{print $1}' || echo "YOUR_SERVER_IP")
return 1
}
# Start service
start_service() {
print_info "$(msg 'starting_service')"
if systemctl start sub2api; then
print_success "$(msg 'service_started')"
return 0
else
print_error "$(msg 'service_start_failed')"
print_info "sudo journalctl -u sub2api -n 50"
return 1
fi
}
# Enable service auto-start
enable_autostart() {
print_info "$(msg 'enabling_autostart')"
if systemctl enable sub2api 2>/dev/null; then
print_success "$(msg 'autostart_enabled')"
return 0
else
print_warning "Failed to enable auto-start"
return 1
fi
}
# Print completion message
print_completion() {
local ip_addr
ip_addr=$(hostname -I 2>/dev/null | awk '{print $1}' || echo "YOUR_SERVER_IP")
# Use PUBLIC_IP which was set by get_public_ip()
# Determine display address
local display_host="$ip_addr"
local display_host="${PUBLIC_IP:-YOUR_SERVER_IP}"
if [ "$SERVER_HOST" = "127.0.0.1" ]; then
display_host="127.0.0.1"
fi
@@ -586,21 +654,9 @@ print_completion() {
echo "$(msg 'server_config_summary'): ${SERVER_HOST}:${SERVER_PORT}"
echo ""
echo "=============================================="
echo " $(msg 'next_steps')"
echo " $(msg 'step4_open_wizard')"
echo "=============================================="
echo ""
echo " 1. $(msg 'step1_check_services')"
echo " sudo systemctl status postgresql"
echo " sudo systemctl status redis"
echo ""
echo " 2. $(msg 'step2_start_service')"
echo " sudo systemctl start sub2api"
echo ""
echo " 3. $(msg 'step3_enable_autostart')"
echo " sudo systemctl enable sub2api"
echo ""
echo " 4. $(msg 'step4_open_wizard')"
echo ""
print_info " http://${display_host}:${SERVER_PORT}"
echo ""
echo " $(msg 'wizard_guide')"
@@ -667,7 +723,7 @@ uninstall() {
exit 1
fi
else
read -p "$(msg 'are_you_sure') " -n 1 -r
read -p "$(msg 'are_you_sure') " -n 1 -r < /dev/tty
echo
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
print_info "$(msg 'uninstall_cancelled')"
@@ -752,6 +808,9 @@ main() {
setup_directories
install_service
prepare_for_setup
get_public_ip
start_service
enable_autostart
print_completion
}

View File

@@ -2,7 +2,7 @@
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<link rel="icon" type="image/png" href="/logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Sub2API - AI API Gateway</title>
</head>

View File

@@ -11,6 +11,7 @@ import type {
PaginatedResponse,
AccountUsageInfo,
WindowStats,
ClaudeModel,
} from '@/types';
/**
@@ -247,6 +248,16 @@ export async function setSchedulable(id: number, schedulable: boolean): Promise<
return data;
}
/**
* Get available models for an account
* @param id - Account ID
* @returns List of available models for this account
*/
export async function getAvailableModels(id: number): Promise<ClaudeModel[]> {
const { data } = await apiClient.get<ClaudeModel[]>(`/admin/accounts/${id}/models`);
return data;
}
export const accountsAPI = {
list,
getById,
@@ -262,6 +273,7 @@ export const accountsAPI = {
getTodayStats,
clearRateLimit,
setSchedulable,
getAvailableModels,
generateAuthUrl,
exchangeCode,
batchCreate,

View File

@@ -40,9 +40,42 @@ export async function checkUpdates(force = false): Promise<VersionInfo> {
return data;
}
export interface UpdateResult {
message: string;
need_restart: boolean;
}
/**
* Perform system update
* Downloads and applies the latest version
*/
export async function performUpdate(): Promise<UpdateResult> {
const { data } = await apiClient.post<UpdateResult>('/admin/system/update');
return data;
}
/**
* Rollback to previous version
*/
export async function rollback(): Promise<UpdateResult> {
const { data } = await apiClient.post<UpdateResult>('/admin/system/rollback');
return data;
}
/**
* Restart the service
*/
export async function restartService(): Promise<{ message: string }> {
const { data } = await apiClient.post<{ message: string }>('/admin/system/restart');
return data;
}
export const systemAPI = {
getVersion,
checkUpdates,
performUpdate,
rollback,
restartService,
};
export default systemAPI;

View File

@@ -36,6 +36,23 @@
</span>
</div>
<!-- Model Selection -->
<div class="space-y-1.5">
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.accounts.selectTestModel') }}
</label>
<select
v-model="selectedModelId"
:disabled="loadingModels || status === 'connecting'"
class="w-full px-3 py-2 text-sm rounded-lg border border-gray-300 dark:border-dark-500 bg-white dark:bg-dark-700 text-gray-900 dark:text-gray-100 focus:ring-2 focus:ring-primary-500 focus:border-primary-500 disabled:opacity-50 disabled:cursor-not-allowed"
>
<option v-if="loadingModels" value="">{{ t('common.loading') }}...</option>
<option v-for="model in availableModels" :key="model.id" :value="model.id">
{{ model.display_name }} ({{ model.id }})
</option>
</select>
</div>
<!-- Terminal Output -->
<div class="relative group">
<div
@@ -125,10 +142,10 @@
</button>
<button
@click="startTest"
:disabled="status === 'connecting'"
:disabled="status === 'connecting' || !selectedModelId"
:class="[
'px-4 py-2 text-sm font-medium rounded-lg transition-all flex items-center gap-2',
status === 'connecting'
status === 'connecting' || !selectedModelId
? 'bg-primary-400 text-white cursor-not-allowed'
: status === 'success'
? 'bg-green-500 hover:bg-green-600 text-white'
@@ -161,7 +178,8 @@
import { ref, watch, nextTick } from 'vue'
import { useI18n } from 'vue-i18n'
import Modal from '@/components/common/Modal.vue'
import type { Account } from '@/types'
import { adminAPI } from '@/api/admin'
import type { Account, ClaudeModel } from '@/types'
const { t } = useI18n()
@@ -184,17 +202,44 @@ const status = ref<'idle' | 'connecting' | 'success' | 'error'>('idle')
const outputLines = ref<OutputLine[]>([])
const streamingContent = ref('')
const errorMessage = ref('')
const availableModels = ref<ClaudeModel[]>([])
const selectedModelId = ref('')
const loadingModels = ref(false)
let eventSource: EventSource | null = null
// Reset state when modal opens
watch(() => props.show, (newVal) => {
if (newVal) {
// Load available models when modal opens
watch(() => props.show, async (newVal) => {
if (newVal && props.account) {
resetState()
await loadAvailableModels()
} else {
closeEventSource()
}
})
const loadAvailableModels = async () => {
if (!props.account) return
loadingModels.value = true
selectedModelId.value = '' // Reset selection before loading
try {
availableModels.value = await adminAPI.accounts.getAvailableModels(props.account.id)
// Default to first model (usually Sonnet)
if (availableModels.value.length > 0) {
// Try to select Sonnet as default, otherwise use first model
const sonnetModel = availableModels.value.find(m => m.id.includes('sonnet'))
selectedModelId.value = sonnetModel?.id || availableModels.value[0].id
}
} catch (error) {
console.error('Failed to load available models:', error)
// Fallback to empty list
availableModels.value = []
selectedModelId.value = ''
} finally {
loadingModels.value = false
}
}
const resetState = () => {
status.value = 'idle'
outputLines.value = []
@@ -227,7 +272,7 @@ const scrollToBottom = async () => {
}
const startTest = async () => {
if (!props.account) return
if (!props.account || !selectedModelId.value) return
resetState()
status.value = 'connecting'
@@ -247,7 +292,8 @@ const startTest = async () => {
headers: {
'Authorization': `Bearer ${localStorage.getItem('auth_token')}`,
'Content-Type': 'application/json'
}
},
body: JSON.stringify({ model_id: selectedModelId.value })
})
if (!response.ok) {

View File

@@ -418,6 +418,31 @@
</div>
</div>
<!-- Intercept Warmup Requests (all account types) -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.interceptWarmupRequests') }}</label>
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">{{ t('admin.accounts.interceptWarmupRequestsDesc') }}</p>
</div>
<button
type="button"
@click="interceptWarmupRequests = !interceptWarmupRequests"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
interceptWarmupRequests ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
interceptWarmupRequests ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.proxy') }}</label>
<ProxySelector
@@ -590,6 +615,7 @@ const allowedModels = ref<string[]>([])
const customErrorCodesEnabled = ref(false)
const selectedErrorCodes = ref<number[]>([])
const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false)
// Common models for whitelist
const commonModels = [
@@ -758,6 +784,7 @@ const resetForm = () => {
customErrorCodesEnabled.value = false
selectedErrorCodes.value = []
customErrorCodeInput.value = null
interceptWarmupRequests.value = false
oauth.resetState()
oauthFlowRef.value?.reset()
}
@@ -801,6 +828,11 @@ const handleSubmit = async () => {
credentials.custom_error_codes = [...selectedErrorCodes.value]
}
// Add intercept warmup requests setting
if (interceptWarmupRequests.value) {
credentials.intercept_warmup_requests = true
}
form.credentials = credentials
submitting.value = true
@@ -847,11 +879,17 @@ const handleExchangeCode = async () => {
const extra = oauth.buildExtraInfo(tokenInfo)
// Merge interceptWarmupRequests into credentials
const credentials = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
}
await adminAPI.accounts.create({
name: form.name,
platform: form.platform,
type: addMethod.value, // Use addMethod as type: 'oauth' or 'setup-token'
credentials: tokenInfo,
credentials,
extra,
proxy_id: form.proxy_id,
concurrency: form.concurrency,
@@ -901,11 +939,17 @@ const handleCookieAuth = async (sessionKey: string) => {
const extra = oauth.buildExtraInfo(tokenInfo)
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
// Merge interceptWarmupRequests into credentials
const credentials = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
}
await adminAPI.accounts.create({
name: accountName,
platform: form.platform,
type: addMethod.value, // Use addMethod as type: 'oauth' or 'setup-token'
credentials: tokenInfo,
credentials,
extra,
proxy_id: form.proxy_id,
concurrency: form.concurrency,

View File

@@ -286,6 +286,31 @@
</div>
</div>
<!-- Intercept Warmup Requests (all account types) -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.interceptWarmupRequests') }}</label>
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">{{ t('admin.accounts.interceptWarmupRequestsDesc') }}</p>
</div>
<button
type="button"
@click="interceptWarmupRequests = !interceptWarmupRequests"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
interceptWarmupRequests ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
interceptWarmupRequests ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.proxy') }}</label>
<ProxySelector
@@ -401,6 +426,7 @@ const allowedModels = ref<string[]>([])
const customErrorCodesEnabled = ref(false)
const selectedErrorCodes = ref<number[]>([])
const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false)
// Common models for whitelist
const commonModels = [
@@ -459,6 +485,10 @@ watch(() => props.account, (newAccount) => {
form.status = newAccount.status as 'active' | 'inactive'
form.group_ids = newAccount.group_ids || []
// Load intercept warmup requests setting (applies to all account types)
const credentials = newAccount.credentials as Record<string, unknown> | undefined
interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true
// Initialize API Key fields for apikey type
if (newAccount.type === 'apikey' && newAccount.credentials) {
const credentials = newAccount.credentials as Record<string, unknown>
@@ -630,6 +660,23 @@ const handleSubmit = async () => {
newCredentials.custom_error_codes = [...selectedErrorCodes.value]
}
// Add intercept warmup requests setting
if (interceptWarmupRequests.value) {
newCredentials.intercept_warmup_requests = true
}
updatePayload.credentials = newCredentials
} else {
// For oauth/setup-token types, only update intercept_warmup_requests if changed
const currentCredentials = props.account.credentials as Record<string, unknown> || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials }
if (interceptWarmupRequests.value) {
newCredentials.intercept_warmup_requests = true
} else {
delete newCredentials.intercept_warmup_requests
}
updatePayload.credentials = newCredentials
}

View File

@@ -12,7 +12,8 @@
]"
:title="hasUpdate ? 'New version available' : 'Up to date'"
>
<span class="font-medium">v{{ currentVersion }}</span>
<span v-if="currentVersion" class="font-medium">v{{ currentVersion }}</span>
<span v-else class="font-medium w-12 h-3 bg-gray-200 dark:bg-dark-600 rounded animate-pulse"></span>
<!-- Update indicator -->
<span v-if="hasUpdate" class="relative flex h-2 w-2">
<span class="animate-ping absolute inline-flex h-full w-full rounded-full bg-amber-400 opacity-75"></span>
@@ -56,7 +57,8 @@
<!-- Version display - centered and prominent -->
<div class="text-center mb-4">
<div class="inline-flex items-center gap-2">
<span class="text-2xl font-bold text-gray-900 dark:text-white">v{{ currentVersion }}</span>
<span v-if="currentVersion" class="text-2xl font-bold text-gray-900 dark:text-white">v{{ currentVersion }}</span>
<span v-else class="text-2xl font-bold text-gray-400 dark:text-dark-500">--</span>
<!-- Show check mark when up to date -->
<span v-if="!hasUpdate" class="flex items-center justify-center w-5 h-5 rounded-full bg-green-100 dark:bg-green-900/30">
<svg class="w-3 h-3 text-green-600 dark:text-green-400" fill="currentColor" viewBox="0 0 20 20">
@@ -69,8 +71,63 @@
</p>
</div>
<!-- Update available for source build - show git pull hint -->
<div v-if="hasUpdate && !isReleaseBuild" class="space-y-2">
<!-- Priority 1: Update error (must check before hasUpdate) -->
<div v-if="updateError" class="space-y-2">
<div class="flex items-center gap-3 p-3 rounded-lg bg-red-50 dark:bg-red-900/20 border border-red-200 dark:border-red-800/50">
<div class="flex-shrink-0 w-8 h-8 rounded-full bg-red-100 dark:bg-red-900/50 flex items-center justify-center">
<svg class="w-4 h-4 text-red-600 dark:text-red-400" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" />
</svg>
</div>
<div class="flex-1 min-w-0">
<p class="text-sm font-medium text-red-700 dark:text-red-300">{{ t('version.updateFailed') }}</p>
<p class="text-xs text-red-600/70 dark:text-red-400/70 truncate">{{ updateError }}</p>
</div>
</div>
<!-- Retry button -->
<button
@click="handleUpdate"
:disabled="updating"
class="w-full flex items-center justify-center gap-2 px-4 py-2 rounded-lg text-sm font-medium text-white bg-red-500 hover:bg-red-600 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
>
{{ t('version.retry') }}
</button>
</div>
<!-- Priority 2: Update success - need restart -->
<div v-else-if="updateSuccess && needRestart" class="space-y-2">
<div class="flex items-center gap-3 p-3 rounded-lg bg-green-50 dark:bg-green-900/20 border border-green-200 dark:border-green-800/50">
<div class="flex-shrink-0 w-8 h-8 rounded-full bg-green-100 dark:bg-green-900/50 flex items-center justify-center">
<svg class="w-4 h-4 text-green-600 dark:text-green-400" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" />
</svg>
</div>
<div class="flex-1 min-w-0">
<p class="text-sm font-medium text-green-700 dark:text-green-300">{{ t('version.updateComplete') }}</p>
<p class="text-xs text-green-600/70 dark:text-green-400/70">{{ t('version.restartRequired') }}</p>
</div>
</div>
<!-- Restart button -->
<button
@click="handleRestart"
:disabled="restarting"
class="w-full flex items-center justify-center gap-2 px-4 py-2 rounded-lg text-sm font-medium text-white bg-green-500 hover:bg-green-600 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
>
<svg v-if="restarting" class="animate-spin h-4 w-4" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
<svg v-else class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
</svg>
{{ restarting ? t('version.restarting') : t('version.restartNow') }}
</button>
</div>
<!-- Priority 3: Update available for source build - show git pull hint -->
<div v-else-if="hasUpdate && !isReleaseBuild" class="space-y-2">
<a
v-if="releaseInfo?.html_url && releaseInfo.html_url !== '#'"
:href="releaseInfo.html_url"
@@ -100,29 +157,53 @@
</div>
</div>
<!-- Update available for release build - show download link -->
<a
v-else-if="hasUpdate && isReleaseBuild && releaseInfo?.html_url && releaseInfo.html_url !== '#'"
:href="releaseInfo.html_url"
target="_blank"
rel="noopener noreferrer"
class="flex items-center gap-3 p-3 rounded-lg bg-amber-50 dark:bg-amber-900/20 border border-amber-200 dark:border-amber-800/50 hover:bg-amber-100 dark:hover:bg-amber-900/30 transition-colors group"
>
<div class="flex-shrink-0 w-8 h-8 rounded-full bg-amber-100 dark:bg-amber-900/50 flex items-center justify-center">
<svg class="w-4 h-4 text-amber-600 dark:text-amber-400" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<!-- Priority 4: Update available for release build - show update button -->
<div v-else-if="hasUpdate && isReleaseBuild" class="space-y-2">
<!-- Update info card -->
<div class="flex items-center gap-3 p-3 rounded-lg bg-amber-50 dark:bg-amber-900/20 border border-amber-200 dark:border-amber-800/50">
<div class="flex-shrink-0 w-8 h-8 rounded-full bg-amber-100 dark:bg-amber-900/50 flex items-center justify-center">
<svg class="w-4 h-4 text-amber-600 dark:text-amber-400" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</div>
<div class="flex-1 min-w-0">
<p class="text-sm font-medium text-amber-700 dark:text-amber-300">{{ t('version.updateAvailable') }}</p>
<p class="text-xs text-amber-600/70 dark:text-amber-400/70">v{{ latestVersion }}</p>
</div>
</div>
<!-- Update button -->
<button
@click="handleUpdate"
:disabled="updating"
class="w-full flex items-center justify-center gap-2 px-4 py-2 rounded-lg text-sm font-medium text-white bg-primary-500 hover:bg-primary-600 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
>
<svg v-if="updating" class="animate-spin h-4 w-4" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
<svg v-else class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</div>
<div class="flex-1 min-w-0">
<p class="text-sm font-medium text-amber-700 dark:text-amber-300">{{ t('version.updateAvailable') }}</p>
<p class="text-xs text-amber-600/70 dark:text-amber-400/70">v{{ latestVersion }}</p>
</div>
<svg class="w-4 h-4 text-amber-500 dark:text-amber-400 group-hover:translate-x-0.5 transition-transform" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M9 5l7 7-7 7" />
</svg>
</a>
{{ updating ? t('version.updating') : t('version.updateNow') }}
</button>
<!-- GitHub link when up to date -->
<!-- View release link -->
<a
v-if="releaseInfo?.html_url && releaseInfo.html_url !== '#'"
:href="releaseInfo.html_url"
target="_blank"
rel="noopener noreferrer"
class="flex items-center justify-center gap-1 text-xs text-gray-500 dark:text-dark-400 hover:text-gray-700 dark:hover:text-dark-200 transition-colors"
>
{{ t('version.viewChangelog') }}
<svg class="w-3 h-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M10 6H6a2 2 0 00-2 2v10a2 2 0 002 2h10a2 2 0 002-2v-4M14 4h6m0 0v6m0-6L10 14" />
</svg>
</a>
</div>
<!-- Priority 5: Up to date - show GitHub link -->
<a
v-else-if="releaseInfo?.html_url && releaseInfo.html_url !== '#'"
:href="releaseInfo.html_url"
@@ -154,8 +235,8 @@
<script setup lang="ts">
import { ref, computed, onMounted, onBeforeUnmount } from 'vue';
import { useI18n } from 'vue-i18n';
import { useAuthStore } from '@/stores';
import { checkUpdates, type VersionInfo, type ReleaseInfo } from '@/api/admin/system';
import { useAuthStore, useAppStore } from '@/stores';
import { performUpdate, restartService } from '@/api/admin/system';
const { t } = useI18n();
@@ -164,18 +245,27 @@ const props = defineProps<{
}>();
const authStore = useAuthStore();
const appStore = useAppStore();
const isAdmin = computed(() => authStore.isAdmin);
const loading = ref(false);
const dropdownOpen = ref(false);
const dropdownRef = ref<HTMLElement | null>(null);
const currentVersion = ref('0.1.0');
const latestVersion = ref('0.1.0');
const hasUpdate = ref(false);
const releaseInfo = ref<ReleaseInfo | null>(null);
const buildType = ref('source'); // "source" or "release"
// Use store's cached version state
const loading = computed(() => appStore.versionLoading);
const currentVersion = computed(() => appStore.currentVersion || props.version || '');
const latestVersion = computed(() => appStore.latestVersion);
const hasUpdate = computed(() => appStore.hasUpdate);
const releaseInfo = computed(() => appStore.releaseInfo);
const buildType = computed(() => appStore.buildType);
// Update process states (local to this component)
const updating = ref(false);
const restarting = ref(false);
const needRestart = ref(false);
const updateError = ref('');
const updateSuccess = ref(false);
// Only show update check for release builds (binary/docker deployment)
const isReleaseBuild = computed(() => buildType.value === 'release');
@@ -191,22 +281,54 @@ function closeDropdown() {
async function refreshVersion(force = true) {
if (!isAdmin.value) return;
loading.value = true;
// Reset update states when refreshing
updateError.value = '';
updateSuccess.value = false;
needRestart.value = false;
await appStore.fetchVersion(force);
}
async function handleUpdate() {
if (updating.value) return;
updating.value = true;
updateError.value = '';
updateSuccess.value = false;
try {
const data: VersionInfo = await checkUpdates(force);
currentVersion.value = data.current_version;
latestVersion.value = data.latest_version;
buildType.value = data.build_type || 'source';
// Show update indicator for all build types
hasUpdate.value = data.has_update;
releaseInfo.value = data.release_info || null;
} catch (error) {
console.error('Failed to check updates:', error);
const result = await performUpdate();
updateSuccess.value = true;
needRestart.value = result.need_restart;
// Clear version cache to reflect update completed
appStore.clearVersionCache();
} catch (error: unknown) {
const err = error as { response?: { data?: { message?: string } }; message?: string };
updateError.value = err.response?.data?.message || err.message || t('version.updateFailed');
} finally {
loading.value = false;
updating.value = false;
}
}
async function handleRestart() {
if (restarting.value) return;
restarting.value = true;
try {
await restartService();
// Service will restart, page will reload automatically or show disconnected
} catch (error) {
// Expected - connection will be lost during restart
console.log('Service restarting...');
}
// Show restarting state for a while, then reload
setTimeout(() => {
window.location.reload();
}, 3000);
}
function handleClickOutside(event: MouseEvent) {
const target = event.target as Node;
const button = (event.target as Element).closest('button');
@@ -217,9 +339,8 @@ function handleClickOutside(event: MouseEvent) {
onMounted(() => {
if (isAdmin.value) {
refreshVersion(false);
} else if (props.version) {
currentVersion.value = props.version;
// Use cached version if available, otherwise fetch
appStore.fetchVersion(false);
}
document.addEventListener('click', handleClickOutside);
});

View File

@@ -108,7 +108,7 @@
</router-link>
<a
href="https://github.com/fangyuan99/sub2api"
href="https://github.com/Wei-Shaw/sub2api"
target="_blank"
rel="noopener noreferrer"
@click="closeDropdown"
@@ -207,7 +207,7 @@ const pageDescription = computed(() => {
});
function toggleMobileSidebar() {
appStore.toggleSidebar();
appStore.toggleMobileSidebar();
}
function toggleDropdown() {

View File

@@ -36,6 +36,7 @@
class="sidebar-link mb-1"
:class="{ 'sidebar-link-active': isActive(item.path) }"
:title="sidebarCollapsed ? item.label : undefined"
@click="handleMenuItemClick"
>
<component :is="item.icon" class="w-5 h-5 flex-shrink-0" />
<transition name="fade">
@@ -58,6 +59,7 @@
class="sidebar-link mb-1"
:class="{ 'sidebar-link-active': isActive(item.path) }"
:title="sidebarCollapsed ? item.label : undefined"
@click="handleMenuItemClick"
>
<component :is="item.icon" class="w-5 h-5 flex-shrink-0" />
<transition name="fade">
@@ -77,6 +79,7 @@
class="sidebar-link mb-1"
:class="{ 'sidebar-link-active': isActive(item.path) }"
:title="sidebarCollapsed ? item.label : undefined"
@click="handleMenuItemClick"
>
<component :is="item.icon" class="w-5 h-5 flex-shrink-0" />
<transition name="fade">
@@ -142,9 +145,9 @@ const appStore = useAppStore();
const authStore = useAuthStore();
const sidebarCollapsed = computed(() => appStore.sidebarCollapsed);
const mobileOpen = computed(() => appStore.mobileOpen);
const isAdmin = computed(() => authStore.isAdmin);
const isDark = ref(document.documentElement.classList.contains('dark'));
const mobileOpen = ref(false);
// Site settings
const siteName = ref('Sub2API');
@@ -303,7 +306,15 @@ function toggleTheme() {
}
function closeMobile() {
mobileOpen.value = false;
appStore.setMobileOpen(false);
}
function handleMenuItemClick() {
if (mobileOpen.value) {
setTimeout(() => {
appStore.setMobileOpen(false);
}, 150);
}
}
function isActive(path: string): boolean {

View File

@@ -266,6 +266,8 @@ export default {
sync: 'Sync',
in: 'In',
out: 'Out',
cacheRead: 'Read',
cacheWrite: 'Write',
rate: 'Rate',
original: 'Original',
billed: 'Billed',
@@ -543,6 +545,7 @@ export default {
title: 'Subscription Settings',
type: 'Billing Type',
typeHint: 'Standard billing deducts from user balance. Subscription mode uses quota limits instead.',
typeNotEditable: 'Billing type cannot be changed after group creation.',
standard: 'Standard (Balance)',
subscription: 'Subscription (Quota)',
dailyLimit: 'Daily Limit (USD)',
@@ -695,6 +698,8 @@ export default {
enterErrorCode: 'Enter error code (100-599)',
invalidErrorCode: 'Please enter a valid HTTP error code (100-599)',
errorCodeExists: 'This error code is already selected',
interceptWarmupRequests: 'Intercept Warmup Requests',
interceptWarmupRequestsDesc: 'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens',
proxy: 'Proxy',
noProxy: 'No Proxy',
concurrency: 'Concurrency',
@@ -776,6 +781,7 @@ export default {
copyOutput: 'Copy output',
startingTestForAccount: 'Starting test for account: {name}',
testAccountTypeLabel: 'Account type: {type}',
selectTestModel: 'Select Test Model',
testModel: 'claude-sonnet-4-5-20250929',
testPrompt: 'Prompt: "hi"',
},
@@ -816,8 +822,8 @@ export default {
standardAdd: 'Standard Add',
batchAdd: 'Quick Add',
batchInput: 'Proxy List',
batchInputPlaceholder: 'Enter one proxy per line in the following formats:\nsocks5://user:pass@192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass@proxy.example.com:443',
batchInputHint: 'Supports http, https, socks5 protocols. Format: protocol://[user:pass@]host:port',
batchInputPlaceholder: "Enter one proxy per line in the following formats:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
batchInputHint: "Supports http, https, socks5 protocols. Format: protocol://[user:pass{'@'}]host:port",
parsedCount: '{count} valid',
invalidCount: '{count} invalid',
duplicateCount: '{count} duplicate',
@@ -1023,9 +1029,18 @@ export default {
noReleaseNotes: 'No release notes',
viewUpdate: 'View Update',
viewRelease: 'View Release',
viewChangelog: 'View Changelog',
refresh: 'Refresh',
sourceMode: 'Source Build',
sourceModeHint: 'Update detection is disabled for source builds. Use git pull to update.',
sourceModeHint: 'Source build, use git pull to update',
updateNow: 'Update Now',
updating: 'Updating...',
updateComplete: 'Update Complete',
updateFailed: 'Update Failed',
restartRequired: 'Please restart the service to apply the update',
restartNow: 'Restart Now',
restarting: 'Restarting...',
retry: 'Retry',
},
// User Subscriptions Page

View File

@@ -266,6 +266,8 @@ export default {
sync: '同步',
in: '输入',
out: '输出',
cacheRead: '读取',
cacheWrite: '写入',
rate: '倍率',
original: '原始',
billed: '计费',
@@ -598,6 +600,7 @@ export default {
title: '订阅设置',
type: '计费类型',
typeHint: '标准计费从用户余额扣除。订阅模式使用配额限制。',
typeNotEditable: '分组创建后无法修改计费类型。',
standard: '标准(余额)',
subscription: '订阅(配额)',
dailyLimit: '每日限额USD',
@@ -785,6 +788,8 @@ export default {
enterErrorCode: '输入错误码 (100-599)',
invalidErrorCode: '请输入有效的 HTTP 错误码 (100-599)',
errorCodeExists: '该错误码已被选中',
interceptWarmupRequests: '拦截预热请求',
interceptWarmupRequestsDesc: '启用后,标题生成等预热请求将返回 mock 响应,不消耗上游 token',
proxy: '代理',
noProxy: '无代理',
concurrency: '并发数',
@@ -864,6 +869,7 @@ export default {
copyOutput: '复制输出',
startingTestForAccount: '开始测试账号:{name}',
testAccountTypeLabel: '账号类型:{type}',
selectTestModel: '选择测试模型',
testModel: 'claude-sonnet-4-5-20250929',
testPrompt: '提示词:"hi"',
},
@@ -941,8 +947,8 @@ export default {
standardAdd: '标准添加',
batchAdd: '快捷添加',
batchInput: '代理列表',
batchInputPlaceholder: '每行输入一个代理,支持以下格式:\nsocks5://user:pass@192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass@proxy.example.com:443',
batchInputHint: '支持 http、https、socks5 协议,格式:协议://[用户名:密码@]主机:端口',
batchInputPlaceholder: "每行输入一个代理,支持以下格式:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
batchInputHint: "支持 http、https、socks5 协议,格式:协议://[用户名:密码{'@'}]主机:端口",
parsedCount: '有效 {count} 个',
invalidCount: '无效 {count} 个',
duplicateCount: '重复 {count} 个',
@@ -1202,9 +1208,18 @@ export default {
noReleaseNotes: '暂无更新日志',
viewUpdate: '查看更新',
viewRelease: '查看发布',
viewChangelog: '查看更新日志',
refresh: '刷新',
sourceMode: '源码构建',
sourceModeHint: '源码构建模式不支持更新检测,请使用 git pull 更新代码。',
sourceModeHint: '源码构建请使用 git pull 更新',
updateNow: '立即更新',
updating: '正在更新...',
updateComplete: '更新完成',
updateFailed: '更新失败',
restartRequired: '请重启服务以应用更新',
restartNow: '立即重启',
restarting: '正在重启...',
retry: '重试',
},
// User Subscriptions Page

View File

@@ -9,4 +9,8 @@ const app = createApp(App)
app.use(createPinia())
app.use(router)
app.use(i18n)
app.mount('#app')
// 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染
router.isReady().then(() => {
app.mount('#app')
})

View File

@@ -305,9 +305,10 @@ router.beforeEach((to, _from, next) => {
// If route doesn't require auth, allow access
if (!requiresAuth) {
// If already authenticated and trying to access login/register, redirect to dashboard
// If already authenticated and trying to access login/register, redirect to appropriate dashboard
if (authStore.isAuthenticated && (to.path === '/login' || to.path === '/register')) {
next('/dashboard');
// Admin users go to admin dashboard, regular users go to user dashboard
next(authStore.isAdmin ? '/admin/dashboard' : '/dashboard');
return;
}
next();

View File

@@ -6,14 +6,25 @@
import { defineStore } from 'pinia';
import { ref, computed } from 'vue';
import type { Toast, ToastType } from '@/types';
import { checkUpdates as checkUpdatesAPI, type VersionInfo, type ReleaseInfo } from '@/api/admin/system';
export const useAppStore = defineStore('app', () => {
// ==================== State ====================
const sidebarCollapsed = ref<boolean>(false);
const mobileOpen = ref<boolean>(false);
const loading = ref<boolean>(false);
const toasts = ref<Toast[]>([]);
// Version cache state
const versionLoaded = ref<boolean>(false);
const versionLoading = ref<boolean>(false);
const currentVersion = ref<string>('');
const latestVersion = ref<string>('');
const hasUpdate = ref<boolean>(false);
const buildType = ref<string>('source');
const releaseInfo = ref<ReleaseInfo | null>(null);
// Auto-incrementing ID for toasts
let toastIdCounter = 0;
@@ -40,6 +51,21 @@ export const useAppStore = defineStore('app', () => {
sidebarCollapsed.value = collapsed;
}
/**
* Toggle mobile sidebar open state
*/
function toggleMobileSidebar(): void {
mobileOpen.value = !mobileOpen.value;
}
/**
* Set mobile sidebar open state explicitly
* @param open - Whether mobile sidebar should be open
*/
function setMobileOpen(open: boolean): void {
mobileOpen.value = open;
}
/**
* Set global loading state
* @param isLoading - Whether app is in loading state
@@ -192,20 +218,82 @@ export const useAppStore = defineStore('app', () => {
toasts.value = [];
}
// ==================== Version Management ====================
/**
* Fetch version info (uses cache unless force=true)
* @param force - Force refresh from API
*/
async function fetchVersion(force = false): Promise<VersionInfo | null> {
// Return cached data if available and not forcing refresh
if (versionLoaded.value && !force) {
return {
current_version: currentVersion.value,
latest_version: latestVersion.value,
has_update: hasUpdate.value,
build_type: buildType.value,
release_info: releaseInfo.value || undefined,
cached: true,
};
}
// Prevent duplicate requests
if (versionLoading.value) {
return null;
}
versionLoading.value = true;
try {
const data = await checkUpdatesAPI(force);
currentVersion.value = data.current_version;
latestVersion.value = data.latest_version;
hasUpdate.value = data.has_update;
buildType.value = data.build_type || 'source';
releaseInfo.value = data.release_info || null;
versionLoaded.value = true;
return data;
} catch (error) {
console.error('Failed to fetch version:', error);
return null;
} finally {
versionLoading.value = false;
}
}
/**
* Clear version cache (e.g., after update)
*/
function clearVersionCache(): void {
versionLoaded.value = false;
hasUpdate.value = false;
}
// ==================== Return Store API ====================
return {
// State
sidebarCollapsed,
mobileOpen,
loading,
toasts,
// Version state
versionLoaded,
versionLoading,
currentVersion,
latestVersion,
hasUpdate,
buildType,
releaseInfo,
// Computed
hasActiveToasts,
// Actions
toggleSidebar,
setSidebarCollapsed,
toggleMobileSidebar,
setMobileOpen,
setLoading,
showToast,
showSuccess,
@@ -217,5 +305,9 @@ export const useAppStore = defineStore('app', () => {
withLoading,
withLoadingAndError,
reset,
// Version actions
fetchVersion,
clearVersionCache,
};
});

View File

@@ -285,6 +285,14 @@ export type AccountType = 'oauth' | 'setup-token' | 'apikey';
export type OAuthAddMethod = 'oauth' | 'setup-token';
export type ProxyProtocol = 'http' | 'https' | 'socks5';
// Claude Model type (returned by /v1/models and account models API)
export interface ClaudeModel {
id: string;
type: string;
display_name: string;
created_at: string;
}
export interface Proxy {
id: number;
name: string;

View File

@@ -282,7 +282,7 @@ const siteSubtitle = ref('AI API Gateway Platform');
const isDark = ref(document.documentElement.classList.contains('dark'));
// GitHub URL
const githubUrl = 'https://github.com/fangyuan99/sub2api';
const githubUrl = 'https://github.com/Wei-Shaw/sub2api';
// Auth state
const isAuthenticated = computed(() => authStore.isAuthenticated);

View File

@@ -406,7 +406,7 @@ const lineOptions = computed(() => ({
// Model chart data
const modelChartData = computed(() => {
if (!modelStats.value.length) return null
if (!modelStats.value?.length) return null
const colors = [
'#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6',
@@ -425,7 +425,7 @@ const modelChartData = computed(() => {
// Trend chart data
const trendChartData = computed(() => {
if (!trendData.value.length) return null
if (!trendData.value?.length) return null
return {
labels: trendData.value.map(d => d.date),
@@ -460,7 +460,7 @@ const trendChartData = computed(() => {
// User trend chart data
const userTrendChartData = computed(() => {
if (!userTrend.value.length) return null
if (!userTrend.value?.length) return null
// Group by user
const userGroups = new Map<string, { name: string; data: Map<string, number> }>()

View File

@@ -180,7 +180,7 @@
/>
<p class="input-hint">{{ t('admin.groups.rateMultiplierHint') }}</p>
</div>
<div class="flex items-center gap-3">
<div v-if="createForm.subscription_type !== 'subscription'" class="flex items-center gap-3">
<button
type="button"
@click="createForm.is_exclusive = !createForm.is_exclusive"
@@ -323,7 +323,7 @@
class="input"
/>
</div>
<div class="flex items-center gap-3">
<div v-if="editForm.subscription_type !== 'subscription'" class="flex items-center gap-3">
<button
type="button"
@click="editForm.is_exclusive = !editForm.is_exclusive"
@@ -360,8 +360,9 @@
<Select
v-model="editForm.subscription_type"
:options="subscriptionTypeOptions"
:disabled="true"
/>
<p class="input-hint">{{ t('admin.groups.subscription.typeHint') }}</p>
<p class="input-hint">{{ t('admin.groups.subscription.typeNotEditable') }}</p>
</div>
<!-- Subscription limits (only show when subscription type is selected) -->
@@ -676,16 +677,11 @@ const confirmDelete = async () => {
}
}
// 监听 subscription_type 变化,配额模式时重置 rate_multiplier 为 1
// 监听 subscription_type 变化,订阅模式时重置 rate_multiplier 为 1is_exclusive 为 true
watch(() => createForm.subscription_type, (newVal) => {
if (newVal === 'subscription') {
createForm.rate_multiplier = 1.0
}
})
watch(() => editForm.subscription_type, (newVal) => {
if (newVal === 'subscription') {
editForm.rate_multiplier = 1.0
createForm.is_exclusive = true
}
})

View File

@@ -647,10 +647,10 @@ const parseProxyUrl = (line: string): {
return {
protocol: protocol.toLowerCase() as ProxyProtocol,
host,
host: host.trim(),
port: portNum,
username: username || '',
password: password || ''
username: username?.trim() || '',
password: password?.trim() || ''
}
}
@@ -714,9 +714,12 @@ const handleCreateProxy = async () => {
submitting.value = true
try {
await adminAPI.proxies.create({
...createForm,
username: createForm.username || null,
password: createForm.password || null
name: createForm.name.trim(),
protocol: createForm.protocol,
host: createForm.host.trim(),
port: createForm.port,
username: createForm.username.trim() || null,
password: createForm.password.trim() || null
})
appStore.showSuccess(t('admin.proxies.proxyCreated'))
closeCreateModal()
@@ -752,17 +755,18 @@ const handleUpdateProxy = async () => {
submitting.value = true
try {
const updateData: any = {
name: editForm.name,
name: editForm.name.trim(),
protocol: editForm.protocol,
host: editForm.host,
host: editForm.host.trim(),
port: editForm.port,
username: editForm.username || null,
username: editForm.username.trim() || null,
status: editForm.status
}
// Only include password if it was changed
if (editForm.password) {
updateData.password = editForm.password
const trimmedPassword = editForm.password.trim()
if (trimmedPassword) {
updateData.password = trimmedPassword
}
await adminAPI.proxies.update(editingProxy.value.id, updateData)

View File

@@ -43,13 +43,12 @@
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
</div>
<div>
<div class="min-w-0 flex-1">
<p class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('usage.totalCost') }}</p>
<div class="flex items-baseline gap-2">
<p class="text-xl font-bold text-green-600 dark:text-green-400">${{ (usageStats?.total_actual_cost || 0).toFixed(4) }}</p>
<span class="text-xs text-gray-400 dark:text-gray-500 line-through">${{ (usageStats?.total_cost || 0).toFixed(4) }}</span>
</div>
<p class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.actualCost') }} / {{ t('usage.standardCost') }}</p>
<p class="text-xl font-bold text-green-600 dark:text-green-400">${{ (usageStats?.total_actual_cost || 0).toFixed(4) }}</p>
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('usage.actualCost') }} / <span class="line-through">${{ (usageStats?.total_cost || 0).toFixed(4) }}</span> {{ t('usage.standardCost') }}
</p>
</div>
</div>
</div>
@@ -195,17 +194,40 @@
</template>
<template #cell-tokens="{ row }">
<div class="text-sm">
<div class="flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.in') }}</span>
<span class="font-medium text-gray-900 dark:text-white">{{ row.input_tokens.toLocaleString() }}</span>
<span class="text-gray-400 dark:text-gray-500">/</span>
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.out') }}</span>
<span class="font-medium text-gray-900 dark:text-white">{{ row.output_tokens.toLocaleString() }}</span>
<div class="text-sm space-y-1.5">
<!-- Input / Output Tokens -->
<div class="flex items-center gap-2">
<!-- Input -->
<div class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-emerald-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
</svg>
<span class="font-medium text-gray-900 dark:text-white">{{ row.input_tokens.toLocaleString() }}</span>
</div>
<!-- Output -->
<div class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-violet-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 10l7-7m0 0l7 7m-7-7v18" />
</svg>
<span class="font-medium text-gray-900 dark:text-white">{{ row.output_tokens.toLocaleString() }}</span>
</div>
</div>
<div v-if="row.cache_read_tokens > 0" class="flex items-center gap-1 text-blue-600 dark:text-blue-400">
<span>{{ t('dashboard.cache') }}</span>
<span class="font-medium">{{ row.cache_read_tokens.toLocaleString() }}</span>
<!-- Cache Tokens (Read + Write) -->
<div v-if="row.cache_read_tokens > 0 || row.cache_creation_tokens > 0" class="flex items-center gap-2">
<!-- Cache Read -->
<div v-if="row.cache_read_tokens > 0" class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-sky-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 8h14M5 8a2 2 0 110-4h14a2 2 0 110 4M5 8v10a2 2 0 002 2h10a2 2 0 002-2V8m-9 4h4" />
</svg>
<span class="text-sky-600 dark:text-sky-400 font-medium">{{ formatCacheTokens(row.cache_read_tokens) }}</span>
</div>
<!-- Cache Write -->
<div v-if="row.cache_creation_tokens > 0" class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-amber-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
</svg>
<span class="text-amber-600 dark:text-amber-400 font-medium">{{ formatCacheTokens(row.cache_creation_tokens) }}</span>
</div>
</div>
</div>
</template>
@@ -458,6 +480,16 @@ const formatTokens = (value: number): string => {
return value.toLocaleString()
}
// Compact format for cache tokens in table cells
const formatCacheTokens = (value: number): string => {
if (value >= 1_000_000) {
return `${(value / 1_000_000).toFixed(1)}M`
} else if (value >= 1_000) {
return `${(value / 1_000).toFixed(1)}K`
}
return value.toLocaleString()
}
const formatDateTime = (dateString: string): string => {
const date = new Date(dateString)
return date.toLocaleString('en-US', {
@@ -538,7 +570,7 @@ const exportToCSV = () => {
return
}
const headers = ['User', 'API Key', 'Model', 'Type', 'Input Tokens', 'Output Tokens', 'Cache Tokens', 'Total Cost', 'Billing Type', 'Duration (ms)', 'Time']
const headers = ['User', 'API Key', 'Model', 'Type', 'Input Tokens', 'Output Tokens', 'Cache Read Tokens', 'Cache Write Tokens', 'Total Cost', 'Billing Type', 'Duration (ms)', 'Time']
const rows = usageLogs.value.map(log => [
log.user?.email || '',
log.api_key?.name || '',
@@ -547,6 +579,7 @@ const exportToCSV = () => {
log.input_tokens,
log.output_tokens,
log.cache_read_tokens,
log.cache_creation_tokens,
log.total_cost.toFixed(6),
log.billing_type === 1 ? 'Subscription' : 'Balance',
log.duration_ms,

View File

@@ -531,7 +531,7 @@ const lineOptions = computed(() => ({
// Model chart data
const modelChartData = computed(() => {
if (!modelStats.value.length) return null
if (!modelStats.value?.length) return null
const colors = [
'#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6',
@@ -550,7 +550,7 @@ const modelChartData = computed(() => {
// Trend chart data
const trendChartData = computed(() => {
if (!trendData.value.length) return null
if (!trendData.value?.length) return null
return {
labels: trendData.value.map(d => d.date),
@@ -688,8 +688,9 @@ const loadChartData = async () => {
usageAPI.getDashboardModels({ start_date: startDate.value, end_date: endDate.value }),
])
trendData.value = trendResponse.trend
modelStats.value = modelResponse.models
// Ensure we always have arrays, even if API returns null
trendData.value = trendResponse.trend || []
modelStats.value = modelResponse.models || []
} catch (error) {
console.error('Error loading chart data:', error)
}

View File

@@ -43,13 +43,12 @@
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
</div>
<div>
<div class="min-w-0 flex-1">
<p class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('usage.totalCost') }}</p>
<div class="flex items-baseline gap-2">
<p class="text-xl font-bold text-green-600 dark:text-green-400">${{ (usageStats?.total_actual_cost || 0).toFixed(4) }}</p>
<span class="text-xs text-gray-400 dark:text-gray-500 line-through">${{ (usageStats?.total_cost || 0).toFixed(4) }}</span>
</div>
<p class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.actualCost') }} / {{ t('usage.standardCost') }}</p>
<p class="text-xl font-bold text-green-600 dark:text-green-400">${{ (usageStats?.total_actual_cost || 0).toFixed(4) }}</p>
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('usage.actualCost') }} / <span class="line-through">${{ (usageStats?.total_cost || 0).toFixed(4) }}</span> {{ t('usage.standardCost') }}
</p>
</div>
</div>
</div>
@@ -138,17 +137,40 @@
</template>
<template #cell-tokens="{ row }">
<div class="text-sm">
<div class="flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.in') }}</span>
<span class="font-medium text-gray-900 dark:text-white">{{ row.input_tokens.toLocaleString() }}</span>
<span class="text-gray-400 dark:text-gray-500">/</span>
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.out') }}</span>
<span class="font-medium text-gray-900 dark:text-white">{{ row.output_tokens.toLocaleString() }}</span>
<div class="text-sm space-y-1.5">
<!-- Input / Output Tokens -->
<div class="flex items-center gap-2">
<!-- Input -->
<div class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-emerald-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
</svg>
<span class="font-medium text-gray-900 dark:text-white">{{ row.input_tokens.toLocaleString() }}</span>
</div>
<!-- Output -->
<div class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-violet-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 10l7-7m0 0l7 7m-7-7v18" />
</svg>
<span class="font-medium text-gray-900 dark:text-white">{{ row.output_tokens.toLocaleString() }}</span>
</div>
</div>
<div v-if="row.cache_read_tokens > 0" class="flex items-center gap-1 text-blue-600 dark:text-blue-400">
<span>{{ t('dashboard.cache') }}</span>
<span class="font-medium">{{ row.cache_read_tokens.toLocaleString() }}</span>
<!-- Cache Tokens (Read + Write) -->
<div v-if="row.cache_read_tokens > 0 || row.cache_creation_tokens > 0" class="flex items-center gap-2">
<!-- Cache Read -->
<div v-if="row.cache_read_tokens > 0" class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-sky-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 8h14M5 8a2 2 0 110-4h14a2 2 0 110 4M5 8v10a2 2 0 002 2h10a2 2 0 002-2V8m-9 4h4" />
</svg>
<span class="text-sky-600 dark:text-sky-400 font-medium">{{ formatCacheTokens(row.cache_read_tokens) }}</span>
</div>
<!-- Cache Write -->
<div v-if="row.cache_creation_tokens > 0" class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-amber-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
</svg>
<span class="text-amber-600 dark:text-amber-400 font-medium">{{ formatCacheTokens(row.cache_creation_tokens) }}</span>
</div>
</div>
</div>
</template>
@@ -332,6 +354,16 @@ const formatTokens = (value: number): string => {
return value.toLocaleString()
}
// Compact format for cache tokens in table cells
const formatCacheTokens = (value: number): string => {
if (value >= 1_000_000) {
return `${(value / 1_000_000).toFixed(1)}M`
} else if (value >= 1_000) {
return `${(value / 1_000).toFixed(1)}K`
}
return value.toLocaleString()
}
const formatDateTime = (dateString: string): string => {
const date = new Date(dateString)
return date.toLocaleString('en-US', {
@@ -416,13 +448,14 @@ const exportToCSV = () => {
return
}
const headers = ['Model', 'Type', 'Input Tokens', 'Output Tokens', 'Cache Tokens', 'Total Cost', 'Billing Type', 'First Token (ms)', 'Duration (ms)', 'Time']
const headers = ['Model', 'Type', 'Input Tokens', 'Output Tokens', 'Cache Read Tokens', 'Cache Write Tokens', 'Total Cost', 'Billing Type', 'First Token (ms)', 'Duration (ms)', 'Time']
const rows = usageLogs.value.map(log => [
log.model,
log.stream ? 'Stream' : 'Sync',
log.input_tokens,
log.output_tokens,
log.cache_read_tokens,
log.cache_creation_tokens,
log.total_cost.toFixed(6),
log.billing_type === 1 ? 'Subscription' : 'Balance',
log.first_token_ms ?? '',