Compare commits

...

23 Commits

Author SHA1 Message Date
shaw
7fd94ab78b fix: 修复usage页面未显示缓存写入的问题 2025-12-19 16:57:31 +08:00
shaw
078529e51e chore: 更新docker的postgres版本为18 2025-12-19 16:42:03 +08:00
shaw
23a4cf11c8 fix: 设置默认logo作为favicon 2025-12-19 16:41:00 +08:00
shaw
d1f0902ec0 feat(account): 支持账号级别拦截预热请求
- 新增 intercept_warmup_requests 配置项,存储在 credentials 字段
- 启用后,标题生成、Warmup 等预热请求返回 mock 响应,不消耗上游 token
- 前端支持所有账号类型(OAuth、Setup Token、API Key)的开关配置
- 修复 OAuth 凭证刷新时丢失非 token 配置的问题
2025-12-19 16:39:25 +08:00
shaw
ee86dbca9d feat(account): 账号测试支持选择模型
- 新增 GET /api/v1/admin/accounts/:id/models 接口获取账号可用模型
- 账号测试弹窗新增模型选择下拉框
- 测试时支持传入 model_id 参数,不传则默认使用 Sonnet
- API Key 账号支持根据 model_mapping 映射测试模型
- 将模型常量提取到 claude 包统一管理
2025-12-19 16:00:09 +08:00
Wesley Liddick
733d4c2b85 Merge pull request #6 from dexcoder6/main
fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
2025-12-19 02:59:05 -05:00
dexcoder6
406d3f3cab fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
- 修复移动端无法打开菜单栏的问题
  - 在 app.ts 中添加 mobileOpen 状态管理
  - 修复 AppHeader.vue 中移动端菜单按钮调用错误的方法
  - 修复 AppSidebar.vue 使用本地 ref 而非全局状态的问题

- 添加移动端菜单自动关闭功能
  - 点击菜单项后自动关闭侧边栏
  - 添加 150ms 延迟以显示关闭动画

- 修复使用记录页面总消费卡片溢出问题
  - 调整总消费卡片布局,将删除线价格移至说明行
  - 添加 min-w-0 flex-1 防止内容溢出
  - 保持与其他卡片高度一致

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-19 15:55:42 +08:00
shaw
1ed93a5fd0 refactor: 提取 Claude 客户端常量到独立包
- 新增 internal/pkg/claude 包统一管理 Claude Code 相关常量
  - 统一账号测试逻辑,所有账号类型使用相同的 Claude Code 风格请求
  - 网关服务使用常量包替换硬编码的 beta header 字符串
2025-12-19 15:22:52 +08:00
shaw
463ddea36f fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
batchInputHint 中的 @ 符号需要使用 {'@'} 转义
2025-12-19 11:24:22 +08:00
shaw
e769f67699 fix(setup): 支持从配置文件读取 Setup Wizard 监听地址
Setup Wizard 之前硬编码使用 8080 端口,现在支持从 config.yaml 或
环境变量 (SERVER_HOST, SERVER_PORT) 读取监听地址,方便用户在端口
被占用时使用其他地址启动初始化向导。
2025-12-19 11:21:58 +08:00
shaw
52d2ae9708 feat(gateway): 添加 /v1/messages/count_tokens 端点
实现 Claude API 的 token 计数功能,支持 OAuth、SetupToken 和 ApiKey 三种账号类型。

特点:
- 校验订阅/余额(不扣费)
- 不计算用户和账号并发
- 不记录使用量
- 支持模型映射(ApiKey 账号)
- 支持 OAuth 账号的指纹管理和 401 重试
2025-12-19 11:12:41 +08:00
shaw
2e59998c51 fix: 代理表单字段保存时自动去除前后空格
前后端同时处理,防止因意外空格导致代理连接失败
2025-12-19 10:39:30 +08:00
shaw
32e58115cc fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
转义 batchInputPlaceholder 中的 @ 符号,防止 Vue I18n 将其误解析为链接消息语法
2025-12-19 10:32:22 +08:00
shaw
ba27026399 docs: 调整源码编译步骤的顺序 2025-12-19 09:47:17 +08:00
shaw
c15b419c4c feat(backend): 添加 event_logging 接口直接返回200
将原本在nginx处理的遥测日志请求移至后端,
忽略Claude Code客户端发送的日志数据。
2025-12-19 09:39:57 +08:00
shaw
5bd27a5d17 fix(frontend): 优化分组表单中订阅模式的字段显示逻辑
- 订阅模式下隐藏 Exclusive 字段并默认为开启状态
- 编辑分组时禁用计费类型字段,防止修改
- 移除编辑表单中无用的 subscription_type watch
2025-12-19 08:41:30 +08:00
Wesley Liddick
0e7b8aab8c Merge pull request #4 from NepetaLemon/refactor/backend-wire-provider-sets
refactor(backend): 拆分 Wire ProviderSet
2025-12-18 19:27:49 -05:00
Forest
236908c03d refactor(backend): 拆分 Wire ProviderSet 2025-12-19 00:03:29 +08:00
shaw
67d028cf50 fix: 修复用户修改密码接口404问题
将后端路由与前端API调用对齐:
- /user/profile -> /users/me
- PUT /user/password -> POST /users/me/password
2025-12-18 22:59:49 +08:00
shaw
66ba487697 fix: 修复前端github项目地址 2025-12-18 22:47:42 +08:00
Wesley Liddick
8c7875aa4d Merge pull request #3 from NepetaLemon/refactor/backend-wire-bootstrap
refactor(backend): 引入 Wire 重构服务启动与依赖组装
2025-12-18 09:12:15 -05:00
shaw
145171464f fix: 修复前端多个 bug
1. 版本号闪烁问题
   - 将版本信息缓存到 Pinia store,避免每次路由切换都重新请求
   - 添加加载占位符,版本为空时显示骨架屏

2. 管理员登录跳转问题
   - 管理员登录后现在正确跳转到 /admin/dashboard
   - 普通用户仍跳转到 /dashboard

3. Dashboard 页面空白报错
   - 修复 API 返回 null 时访问 .length 导致的 TypeError
   - 为 computed 属性添加可选链操作符保护
   - 为数据赋值添加空数组默认值
2025-12-18 22:11:29 +08:00
Forest
e5aa676853 refactor(backend): 引入 Wire 重构服务启动与依赖组装 2025-12-18 22:07:17 +08:00
54 changed files with 2160 additions and 860 deletions

View File

@@ -16,6 +16,14 @@ English | [中文](README_CN.md)
--- ---
## Demo
Try Sub2API online: **https://v2.pincc.ai/**
| Email | Password |
|-------|----------|
| admin@sub2api.com | admin123 |
## Overview ## Overview
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding. Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
@@ -208,20 +216,19 @@ Build and run from source code for development or customization.
git clone https://github.com/Wei-Shaw/sub2api.git git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api cd sub2api
# 2. Build backend # 2. Build frontend
cd backend cd frontend
go build -o sub2api ./cmd/server
# 3. Build frontend
cd ../frontend
npm install npm install
npm run build npm run build
# 4. Copy frontend build to backend (for embedding) # 3. Copy frontend build to backend (for embedding)
cp -r dist ../backend/internal/web/ cp -r dist ../backend/internal/web/
# 5. Create configuration file # 4. Build backend (requires frontend dist to be present)
cd ../backend cd ../backend
go build -o sub2api ./cmd/server
# 5. Create configuration file
cp ../deploy/config.example.yaml ./config.yaml cp ../deploy/config.example.yaml ./config.yaml
# 6. Edit configuration # 6. Edit configuration

View File

@@ -16,6 +16,14 @@
--- ---
## 在线体验
体验地址:**https://v2.pincc.ai/**
| 邮箱 | 密码 |
|------|------|
| admin@sub2api.com | admin123 |
## 项目概述 ## 项目概述
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。 Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
@@ -208,20 +216,19 @@ docker-compose logs -f
git clone https://github.com/Wei-Shaw/sub2api.git git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api cd sub2api
# 2. 编译 # 2. 编译
cd backend cd frontend
go build -o sub2api ./cmd/server
# 3. 编译前端
cd ../frontend
npm install npm install
npm run build npm run build
# 4. 复制前端构建产物到后端(用于嵌入) # 3. 复制前端构建产物到后端(用于嵌入)
cp -r dist ../backend/internal/web/ cp -r dist ../backend/internal/web/
# 5. 创建配置文件 # 4. 编译后端(需要前端 dist 目录存在)
cd ../backend cd ../backend
go build -o sub2api ./cmd/server
# 5. 创建配置文件
cp ../deploy/config.example.yaml ./config.yaml cp ../deploy/config.example.yaml ./config.yaml
# 6. 编辑配置 # 6. 编辑配置

6
backend/Makefile Normal file
View File

@@ -0,0 +1,6 @@
.PHONY: wire
wire:
@echo "生成 Wire 代码..."
@cd cmd/server && go generate
@echo "Wire 代码生成完成"

View File

@@ -1,8 +1,11 @@
package main package main
//go:generate go run github.com/google/wire/cmd/wire
import ( import (
"context" "context"
_ "embed" _ "embed"
"errors"
"flag" "flag"
"log" "log"
"net/http" "net/http"
@@ -15,18 +18,10 @@ import (
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/handler" "sub2api/internal/handler"
"sub2api/internal/middleware" "sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/setup" "sub2api/internal/setup"
"sub2api/internal/web" "sub2api/internal/web"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
) )
//go:embed VERSION //go:embed VERSION
@@ -100,8 +95,10 @@ func runSetupServer() {
r.Use(web.ServeEmbeddedFrontend()) r.Use(web.ServeEmbeddedFrontend())
} }
addr := ":8080" // Get server address from config.yaml or environment variables (SERVER_HOST, SERVER_PORT)
log.Printf("Setup wizard available at http://localhost%s", addr) // This allows users to run setup on a different address if needed
addr := config.GetServerAddress()
log.Printf("Setup wizard available at http://%s", addr)
log.Println("Complete the setup wizard to configure Sub2API") log.Println("Complete the setup wizard to configure Sub2API")
if err := r.Run(addr); err != nil { if err := r.Run(addr); err != nil {
@@ -110,78 +107,25 @@ func runSetupServer() {
} }
func runMainServer() { func runMainServer() {
// 加载配置
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 初始化时区(类似 PHP 的 date_default_timezone_set
if err := timezone.Init(cfg.Timezone); err != nil {
log.Fatalf("Failed to initialize timezone: %v", err)
}
// 初始化数据库
db, err := initDB(cfg)
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
// 初始化Redis
rdb := initRedis(cfg)
// 初始化Repository
repos := repository.NewRepositories(db)
// 初始化Service
services := service.NewServices(repos, rdb, cfg)
// 初始化Handler
buildInfo := handler.BuildInfo{ buildInfo := handler.BuildInfo{
Version: Version, Version: Version,
BuildType: BuildType, BuildType: BuildType,
} }
handlers := handler.NewHandlers(services, repos, rdb, buildInfo)
// 设置Gin模式 app, err := initializeApplication(buildInfo)
if cfg.Server.Mode == "release" { if err != nil {
gin.SetMode(gin.ReleaseMode) log.Fatalf("Failed to initialize application: %v", err)
}
// 创建路由
r := gin.New()
r.Use(gin.Recovery())
r.Use(middleware.Logger())
r.Use(middleware.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
} }
defer app.Cleanup()
// 启动服务器 // 启动服务器
srv := &http.Server{
Addr: cfg.Server.Address(),
Handler: r,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
// 优雅关闭
go func() { go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := app.Server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start server: %v", err) log.Fatalf("Failed to start server: %v", err)
} }
}() }()
log.Printf("Server started on %s", cfg.Server.Address()) log.Printf("Server started on %s", app.Server.Addr)
// 等待中断信号 // 等待中断信号
quit := make(chan os.Signal, 1) quit := make(chan os.Signal, 1)
@@ -193,289 +137,9 @@ func runMainServer() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := srv.Shutdown(ctx); err != nil { if err := app.Server.Shutdown(ctx); err != nil {
log.Fatalf("Server forced to shutdown: %v", err) log.Fatalf("Server forced to shutdown: %v", err)
} }
log.Println("Server exited") log.Println("Server exited")
} }
func initDB(cfg *config.Config) (*gorm.DB, error) {
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}
func initRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
// API v1
v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
// OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
}
}

103
backend/cmd/server/wire.go Normal file
View File

@@ -0,0 +1,103 @@
//go:build wireinject
// +build wireinject
package main
import (
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"context"
"log"
"net/http"
"time"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
type Application struct {
Server *http.Server
Cleanup func()
}
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
wire.Build(
// 基础设施层 ProviderSets
config.ProviderSet,
infrastructure.ProviderSet,
// 业务层 ProviderSets
repository.ProviderSet,
service.ProviderSet,
handler.ProviderSet,
// 服务器层 ProviderSet
server.ProviderSet,
// 清理函数提供者
provideCleanup,
// 应用程序结构体
wire.Struct(new(Application), "Server", "Cleanup"),
)
return nil, nil
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Cleanup steps in reverse dependency order
cleanupSteps := []struct {
name string
fn func() error
}{
{"PricingService", func() error {
services.Pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
{"Database", func() error {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}},
}
for _, step := range cleanupSteps {
if err := step.fn(); err != nil {
log.Printf("[Cleanup] %s failed: %v", step.name, err)
// Continue with remaining cleanup steps even if one fails
} else {
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
// Check if context timed out
select {
case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
default:
log.Printf("[Cleanup] All cleanup steps completed")
}
}
}

View File

@@ -0,0 +1,201 @@
// Code generated by Wire. DO NOT EDIT.
//go:generate go run -mod=mod github.com/google/wire/cmd/wire
//go:build !wireinject
// +build !wireinject
package main
import (
"context"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
"log"
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/handler/admin"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"time"
)
import (
_ "embed"
)
// Injectors from wire.go:
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
configConfig, err := config.ProvideConfig()
if err != nil {
return nil, err
}
db, err := infrastructure.ProvideDB(configConfig)
if err != nil {
return nil, err
}
userRepository := repository.NewUserRepository(db)
settingRepository := repository.NewSettingRepository(db)
settingService := service.NewSettingService(settingRepository, configConfig)
client := infrastructure.ProvideRedis(configConfig)
emailService := service.NewEmailService(settingRepository, client)
turnstileService := service.NewTurnstileService(settingService)
emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository, configConfig)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(repositories, billingCacheService)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
adminService := service.NewAdminService(repositories, billingCacheService)
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
oAuthService := service.NewOAuthService(proxyRepository)
rateLimitService := service.NewRateLimitService(repositories, configConfig)
accountUsageService := service.NewAccountUsageService(repositories, oAuthService)
accountTestService := service.NewAccountTestService(repositories, oAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService)
systemHandler := handler.ProvideSystemHandler(client, buildInfo)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
pricingService, err := service.ProvidePricingService(configConfig)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(client)
gatewayService := service.NewGatewayService(repositories, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
concurrencyService := service.NewConcurrencyService(client)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
groupService := service.NewGroupService(groupRepository)
accountService := service.NewAccountService(accountRepository, groupRepository)
proxyService := service.NewProxyService(proxyRepository)
services := &service.Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oAuthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
}
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
httpServer := server.ProvideHTTPServer(configConfig, engine)
v := provideCleanup(db, client, services)
application := &Application{
Server: httpServer,
Cleanup: v,
}
return application, nil
}
// wire.go:
type Application struct {
Server *http.Server
Cleanup func()
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cleanupSteps := []struct {
name string
fn func() error
}{
{"PricingService", func() error {
services.Pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
{"Database", func() error {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}},
}
for _, step := range cleanupSteps {
if err := step.fn(); err != nil {
log.Printf("[Cleanup] %s failed: %v", step.name, err)
} else {
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
select {
case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
default:
log.Printf("[Cleanup] All cleanup steps completed")
}
}
}

View File

@@ -13,6 +13,7 @@ require (
github.com/redis/go-redis/v9 v9.3.0 github.com/redis/go-redis/v9 v9.3.0
github.com/spf13/viper v1.18.2 github.com/spf13/viper v1.18.2
golang.org/x/crypto v0.44.0 golang.org/x/crypto v0.44.0
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0 golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.4 gorm.io/driver/postgres v1.5.4
@@ -33,6 +34,8 @@ require (
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/google/wire v0.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -50,6 +53,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect github.com/refraction-networking/utls v1.8.1 // indirect
@@ -66,9 +70,11 @@ require (
go.uber.org/multierr v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/net v0.47.0 // indirect golang.org/x/mod v0.29.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
) )

View File

@@ -48,8 +48,12 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
@@ -154,8 +158,12 @@ golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
@@ -166,6 +174,8 @@ golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=

View File

@@ -203,3 +203,29 @@ func (c *Config) Validate() error {
} }
return nil return nil
} }
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.
// Priority: config.yaml > environment variables > defaults
func GetServerAddress() string {
v := viper.New()
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("./config")
v.AddConfigPath("/etc/sub2api")
// Support SERVER_HOST and SERVER_PORT environment variables
v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.SetDefault("server.host", "0.0.0.0")
v.SetDefault("server.port", 8080)
// Try to read config file (ignore errors if not found)
_ = v.ReadInConfig()
host := v.GetString("server.host")
port := v.GetInt("server.port")
return fmt.Sprintf("%s:%d", host, port)
}

View File

@@ -0,0 +1,13 @@
package config
import "github.com/google/wire"
// ProviderSet 提供配置层的依赖
var ProviderSet = wire.NewSet(
ProvideConfig,
)
// ProvideConfig 提供应用配置
func ProvideConfig() (*Config, error) {
return Load()
}

View File

@@ -3,6 +3,7 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/pkg/claude"
"sub2api/internal/pkg/response" "sub2api/internal/pkg/response"
"sub2api/internal/service" "sub2api/internal/service"
@@ -186,6 +187,11 @@ func (h *AccountHandler) Delete(c *gin.Context) {
response.Success(c, gin.H{"message": "Account deleted successfully"}) response.Success(c, gin.H{"message": "Account deleted successfully"})
} }
// TestAccountRequest represents the request body for testing an account
type TestAccountRequest struct {
ModelID string `json:"model_id"`
}
// Test handles testing account connectivity with SSE streaming // Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test // POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) { func (h *AccountHandler) Test(c *gin.Context) {
@@ -195,8 +201,12 @@ func (h *AccountHandler) Test(c *gin.Context) {
return return
} }
var req TestAccountRequest
// Allow empty body, model_id is optional
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming // Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil { if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
// Error already sent via SSE, just log // Error already sent via SSE, just log
return return
} }
@@ -231,16 +241,20 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return return
} }
// Update account credentials // Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
newCredentials := map[string]interface{}{ newCredentials := make(map[string]interface{})
"access_token": tokenInfo.AccessToken, for k, v := range account.Credentials {
"token_type": tokenInfo.TokenType, newCredentials[k] = v
"expires_in": tokenInfo.ExpiresIn,
"expires_at": tokenInfo.ExpiresAt,
"refresh_token": tokenInfo.RefreshToken,
"scope": tokenInfo.Scope,
} }
// Update token-related fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
newCredentials["refresh_token"] = tokenInfo.RefreshToken
newCredentials["scope"] = tokenInfo.Scope
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials, Credentials: newCredentials,
}) })
@@ -535,3 +549,58 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
response.Success(c, account) response.Success(c, account)
} }
// GetAvailableModels handles getting available models for an account
// GET /api/v1/admin/accounts/:id/models
func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
response.Success(c, claude.DefaultModels)
return
}
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
// No mapping configured, return default models
response.Success(c, claude.DefaultModels)
return
}
// Return mapped models (keys of the mapping are the available model IDs)
var models []claude.Model
for requestedModel := range mapping {
// Try to find display info from default models
var found bool
for _, dm := range claude.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
// If not found in defaults, create a basic entry
if !found {
models = append(models, claude.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
CreatedAt: "",
})
}
}
response.Success(c, models)
}

View File

@@ -2,6 +2,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"sub2api/internal/pkg/response" "sub2api/internal/pkg/response"
"sub2api/internal/service" "sub2api/internal/service"
@@ -112,12 +113,12 @@ func (h *ProxyHandler) Create(c *gin.Context) {
} }
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: req.Name, Name: strings.TrimSpace(req.Name),
Protocol: req.Protocol, Protocol: strings.TrimSpace(req.Protocol),
Host: req.Host, Host: strings.TrimSpace(req.Host),
Port: req.Port, Port: req.Port,
Username: req.Username, Username: strings.TrimSpace(req.Username),
Password: req.Password, Password: strings.TrimSpace(req.Password),
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error()) response.BadRequest(c, "Failed to create proxy: "+err.Error())
@@ -143,13 +144,13 @@ func (h *ProxyHandler) Update(c *gin.Context) {
} }
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{ proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
Name: req.Name, Name: strings.TrimSpace(req.Name),
Protocol: req.Protocol, Protocol: strings.TrimSpace(req.Protocol),
Host: req.Host, Host: strings.TrimSpace(req.Host),
Port: req.Port, Port: req.Port,
Username: req.Username, Username: strings.TrimSpace(req.Username),
Password: req.Password, Password: strings.TrimSpace(req.Password),
Status: req.Status, Status: strings.TrimSpace(req.Status),
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error()) response.InternalError(c, "Failed to update proxy: "+err.Error())
@@ -263,8 +264,14 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
skipped := 0 skipped := 0
for _, item := range req.Proxies { for _, item := range req.Proxies {
// Trim all string fields
host := strings.TrimSpace(item.Host)
protocol := strings.TrimSpace(item.Protocol)
username := strings.TrimSpace(item.Username)
password := strings.TrimSpace(item.Password)
// Check for duplicates (same host, port, username, password) // Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password) exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil { if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error()) response.InternalError(c, "Failed to check proxy existence: "+err.Error())
return return
@@ -278,11 +285,11 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
// Create proxy with default name // Create proxy with default name
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ _, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: "default", Name: "default",
Protocol: item.Protocol, Protocol: protocol,
Host: item.Host, Host: host,
Port: item.Port, Port: item.Port,
Username: item.Username, Username: username,
Password: item.Password, Password: password,
}) })
if err != nil { if err != nil {
// If creation fails due to duplicate, count as skipped // If creation fails due to duplicate, count as skipped

View File

@@ -7,10 +7,12 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"strings"
"time" "time"
"sub2api/internal/middleware" "sub2api/internal/middleware"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/claude"
"sub2api/internal/service" "sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -126,6 +128,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted) accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
if err != nil { if err != nil {
@@ -285,29 +297,8 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id
// Models handles listing available models // Models handles listing available models
// GET /v1/models // GET /v1/models
func (h *GatewayHandler) Models(c *gin.Context) { func (h *GatewayHandler) Models(c *gin.Context) {
models := []gin.H{
{
"id": "claude-opus-4-5-20251101",
"type": "model",
"display_name": "Claude Opus 4.5",
"created_at": "2025-11-01T00:00:00Z",
},
{
"id": "claude-sonnet-4-5-20250929",
"type": "model",
"display_name": "Claude Sonnet 4.5",
"created_at": "2025-09-29T00:00:00Z",
},
{
"id": "claude-haiku-4-5-20251001",
"type": "model",
"display_name": "Claude Haiku 4.5",
"created_at": "2025-10-01T00:00:00Z",
},
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"data": models, "data": claude.DefaultModels,
"object": "list", "object": "list",
}) })
} }
@@ -443,3 +434,155 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
}, },
}) })
} }
// CountTokens handles token counting endpoint
// POST /v1/messages/count_tokens
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// 解析请求获取模型名
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 获取订阅信息可能为nil
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return
}
// 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理
return
}
}
// isWarmupRequest 检测是否为预热请求标题生成、Warmup等
func isWarmupRequest(body []byte) bool {
// 快速检查如果body不包含关键字直接返回false
bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
return false
}
// 解析完整请求
var req struct {
Messages []struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
} `json:"messages"`
System []struct {
Text string `json:"text"`
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
return false
}
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return true
}
}
}
}
// 检查 system 中的标题提取模式
for _, system := range req.System {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return true
}
}
return false
}
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
events := []string{
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
time.Sleep(20 * time.Millisecond)
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
"output_tokens": 2,
},
})
}

View File

@@ -2,10 +2,6 @@ package handler
import ( import (
"sub2api/internal/handler/admin" "sub2api/internal/handler/admin"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/redis/go-redis/v9"
) )
// AdminHandlers contains all admin-related HTTP handlers // AdminHandlers contains all admin-related HTTP handlers
@@ -41,30 +37,3 @@ type BuildInfo struct {
Version string Version string
BuildType string // "source" for manual builds, "release" for CI builds BuildType string // "source" for manual builds, "release" for CI builds
} }
// NewHandlers creates a new Handlers instance with all handlers initialized
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
return &Handlers{
Auth: NewAuthHandler(services.Auth),
User: NewUserHandler(services.User),
APIKey: NewAPIKeyHandler(services.ApiKey),
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
Redeem: NewRedeemHandler(services.Redeem),
Subscription: NewSubscriptionHandler(services.Subscription),
Admin: &AdminHandlers{
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
User: admin.NewUserHandler(services.Admin),
Group: admin.NewGroupHandler(services.Admin),
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
Proxy: admin.NewProxyHandler(services.Admin),
Redeem: admin.NewRedeemHandler(services.Admin),
Setting: admin.NewSettingHandler(services.Setting, services.Email),
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
Subscription: admin.NewSubscriptionHandler(services.Subscription),
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
},
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
}
}

View File

@@ -0,0 +1,103 @@
package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/service"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
// ProvideAdminHandlers creates the AdminHandlers struct
func ProvideAdminHandlers(
dashboardHandler *admin.DashboardHandler,
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
oauthHandler *admin.OAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
settingHandler *admin.SettingHandler,
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
User: userHandler,
Group: groupHandler,
Account: accountHandler,
OAuth: oauthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Setting: settingHandler,
System: systemHandler,
Subscription: subscriptionHandler,
Usage: usageHandler,
}
}
// ProvideSystemHandler creates admin.SystemHandler with BuildInfo parameters
func ProvideSystemHandler(rdb *redis.Client, buildInfo BuildInfo) *admin.SystemHandler {
return admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType)
}
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
return NewSettingHandler(settingService, buildInfo.Version)
}
// ProvideHandlers creates the Handlers struct
func ProvideHandlers(
authHandler *AuthHandler,
userHandler *UserHandler,
apiKeyHandler *APIKeyHandler,
usageHandler *UsageHandler,
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
settingHandler *SettingHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
User: userHandler,
APIKey: apiKeyHandler,
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
Setting: settingHandler,
}
}
// ProviderSet is the Wire provider set for all handlers
var ProviderSet = wire.NewSet(
// Top-level handlers
NewAuthHandler,
NewUserHandler,
NewAPIKeyHandler,
NewUsageHandler,
NewRedeemHandler,
NewSubscriptionHandler,
NewGatewayHandler,
ProvideSettingHandler,
// Admin handlers
admin.NewDashboardHandler,
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewSettingHandler,
ProvideSystemHandler,
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
ProvideHandlers,
)

View File

@@ -0,0 +1,38 @@
package infrastructure
import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// InitDB 初始化数据库连接
func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 初始化时区(在数据库连接之前,确保时区设置正确)
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, err
}
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}

View File

@@ -0,0 +1,16 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}

View File

@@ -0,0 +1,25 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// ProviderSet 提供基础设施层的依赖
var ProviderSet = wire.NewSet(
ProvideDB,
ProvideRedis,
)
// ProvideDB 提供数据库连接
func ProvideDB(cfg *config.Config) (*gorm.DB, error) {
return InitDB(cfg)
}
// ProvideRedis 提供 Redis 客户端
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

View File

@@ -263,3 +263,17 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
} }
return false return false
} }
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
// 启用后标题生成、Warmup等预热请求将返回mock响应不消耗上游token
func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil {
return false
}
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}

View File

@@ -0,0 +1,74 @@
package claude
// Claude Code 客户端相关常量
// Beta header 常量
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0",
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
// Model 表示一个 Claude 模型
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels Claude Code 客户端支持的默认模型列表
var DefaultModels = []Model{
{
ID: "claude-opus-4-5-20251101",
Type: "model",
DisplayName: "Claude Opus 4.5",
CreatedAt: "2025-11-01T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",
DisplayName: "Claude Sonnet 4.5",
CreatedAt: "2025-09-29T00:00:00Z",
},
{
ID: "claude-haiku-4-5-20251001",
Type: "model",
DisplayName: "Claude Haiku 4.5",
CreatedAt: "2025-10-01T00:00:00Z",
},
}
// DefaultModelIDs 返回默认模型的 ID 列表
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929"

View File

@@ -1,9 +1,5 @@
package repository package repository
import (
"gorm.io/gorm"
)
// Repositories 所有仓库的集合 // Repositories 所有仓库的集合
type Repositories struct { type Repositories struct {
User *UserRepository User *UserRepository
@@ -17,21 +13,6 @@ type Repositories struct {
UserSubscription *UserSubscriptionRepository UserSubscription *UserSubscriptionRepository
} }
// NewRepositories 创建所有仓库
func NewRepositories(db *gorm.DB) *Repositories {
return &Repositories{
User: NewUserRepository(db),
ApiKey: NewApiKeyRepository(db),
Group: NewGroupRepository(db),
Account: NewAccountRepository(db),
Proxy: NewProxyRepository(db),
RedeemCode: NewRedeemCodeRepository(db),
UsageLog: NewUsageLogRepository(db),
Setting: NewSettingRepository(db),
UserSubscription: NewUserSubscriptionRepository(db),
}
}
// PaginationParams 分页参数 // PaginationParams 分页参数
type PaginationParams struct { type PaginationParams struct {
Page int Page int

View File

@@ -0,0 +1,19 @@
package repository
import (
"github.com/google/wire"
)
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewApiKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewUsageLogRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
wire.Struct(new(Repositories), "*"),
)

View File

@@ -0,0 +1,45 @@
package server
import (
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/repository"
"sub2api/internal/service"
"time"
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// ProviderSet 提供服务器层的依赖
var ProviderSet = wire.NewSet(
ProvideRouter,
ProvideHTTPServer,
)
// ProvideRouter 提供路由器
func ProvideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
}
r := gin.New()
r.Use(gin.Recovery())
return SetupRouter(r, cfg, handlers, services, repos)
}
// ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
return &http.Server{
Addr: cfg.Server.Address(),
Handler: router,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
}

View File

@@ -0,0 +1,289 @@
package server
import (
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/web"
"github.com/gin-gonic/gin"
)
// SetupRouter 配置路由器中间件和路由
func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
// 应用中间件
r.Use(middleware.Logger())
r.Use(middleware.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
}
return r
}
// registerRoutes 注册所有 HTTP 路由
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Claude Code 遥测日志忽略直接返回200
r.POST("/api/event_logging/batch", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
// API v1
v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
// OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
}
}

View File

@@ -15,6 +15,7 @@ import (
"strings" "strings"
"time" "time"
"sub2api/internal/pkg/claude"
"sub2api/internal/repository" "sub2api/internal/repository"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -23,7 +24,6 @@ import (
const ( const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages" testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testModel = "claude-sonnet-4-5-20250929"
) )
// TestEvent represents a SSE event for account testing // TestEvent represents a SSE event for account testing
@@ -62,10 +62,10 @@ func generateSessionString() string {
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID) return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
} }
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts // createTestPayload creates a Claude Code style test request payload
func createTestPayload() map[string]interface{} { func createTestPayload(modelID string) map[string]interface{} {
return map[string]interface{}{ return map[string]interface{}{
"model": testModel, "model": modelID,
"messages": []map[string]interface{}{ "messages": []map[string]interface{}{
{ {
"role": "user", "role": "user",
@@ -92,29 +92,16 @@ func createTestPayload() map[string]interface{} {
"metadata": map[string]string{ "metadata": map[string]string{
"user_id": generateSessionString(), "user_id": generateSessionString(),
}, },
"max_tokens": 1024, "max_tokens": 1024,
"temperature": 1, "temperature": 1,
"stream": true, "stream": true,
}
}
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
func createApiKeyTestPayload(model string) map[string]interface{} {
return map[string]interface{}{
"model": model,
"messages": []map[string]interface{}{
{
"role": "user",
"content": "hi",
},
},
"max_tokens": 1024,
"stream": true,
} }
} }
// TestAccountConnection tests an account's connection by sending a test request // TestAccountConnection tests an account's connection by sending a test request
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64) error { // All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Get account // Get account
@@ -123,14 +110,30 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.sendErrorAndEnd(c, "Account not found") return s.sendErrorAndEnd(c, "Account not found")
} }
// Determine authentication method based on account type // Determine the model to use
testModelID := modelID
if testModelID == "" {
testModelID = claude.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if mapping != nil && len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
}
// Determine authentication method and API URL
var authToken string var authToken string
var authType string // "bearer" for OAuth, "apikey" for API Key var useBearer bool
var apiURL string var apiURL string
if account.IsOAuth() { if account.IsOAuth() {
// OAuth or Setup Token account // OAuth or Setup Token - use Bearer token
authType = "bearer" useBearer = true
apiURL = testClaudeAPIURL apiURL = testClaudeAPIURL
authToken = account.GetCredential("access_token") authToken = account.GetCredential("access_token")
if authToken == "" { if authToken == "" {
@@ -141,7 +144,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
needRefresh := false needRefresh := false
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" { if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+300 > expiresAt { // 5 minute buffer if err == nil && time.Now().Unix()+300 > expiresAt {
needRefresh = true needRefresh = true
} }
} }
@@ -154,19 +157,17 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
authToken = tokenInfo.AccessToken authToken = tokenInfo.AccessToken
} }
} else if account.Type == "apikey" { } else if account.Type == "apikey" {
// API Key account // API Key - use x-api-key header
authType = "apikey" useBearer = false
authToken = account.GetCredential("api_key") authToken = account.GetCredential("api_key")
if authToken == "" { if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available") return s.sendErrorAndEnd(c, "No API key available")
} }
// Get base URL (use default if not set)
apiURL = account.GetBaseURL() apiURL = account.GetBaseURL()
if apiURL == "" { if apiURL == "" {
apiURL = "https://api.anthropic.com" apiURL = "https://api.anthropic.com"
} }
// Append /v1/messages endpoint
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages" apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
} else { } else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
@@ -179,37 +180,32 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
c.Writer.Header().Set("X-Accel-Buffering", "no") c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush() c.Writer.Flush()
// Create test request payload // Create Claude Code style payload (same for all account types)
var payload map[string]interface{} payload := createTestPayload(testModelID)
var actualModel string
if authType == "apikey" {
// Use simpler payload for API Key (without Claude Code specific fields)
// Apply model mapping if configured
actualModel = account.GetMappedModel(testModel)
payload = createApiKeyTestPayload(actualModel)
} else {
actualModel = testModel
payload = createTestPayload()
}
payloadBytes, _ := json.Marshal(payload) payloadBytes, _ := json.Marshal(payload)
// Send test_start event with model info // Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: actualModel}) s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes)) req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request") return s.sendErrorAndEnd(c, "Failed to create request")
} }
// Set headers based on auth type // Set common headers
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
if authType == "bearer" { // Apply Claude Code client headers
for key, value := range claude.DefaultHeaders {
req.Header.Set(key, value)
}
// Set authentication header
if useBearer {
req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19")
} else { } else {
// API Key uses x-api-key header
req.Header.Set("x-api-key", authToken) req.Header.Set("x-api-key", authToken)
} }
@@ -252,7 +248,6 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
line, err := reader.ReadString('\n') line, err := reader.ReadString('\n')
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
// Stream ended, send complete event
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil return nil
} }
@@ -310,5 +305,5 @@ func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error { func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
log.Printf("Account test error: %s", errorMsg) log.Printf("Account test error: %s", errorMsg)
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg}) s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
return fmt.Errorf(errorMsg) return fmt.Errorf("%s", errorMsg)
} }

View File

@@ -191,24 +191,17 @@ type adminServiceImpl struct {
} }
// NewAdminService creates a new AdminService // NewAdminService creates a new AdminService
func NewAdminService(repos *repository.Repositories) AdminService { func NewAdminService(repos *repository.Repositories, billingCacheService *BillingCacheService) AdminService {
return &adminServiceImpl{ return &adminServiceImpl{
userRepo: repos.User, userRepo: repos.User,
groupRepo: repos.Group, groupRepo: repos.Group,
accountRepo: repos.Account, accountRepo: repos.Account,
proxyRepo: repos.Proxy, proxyRepo: repos.Proxy,
apiKeyRepo: repos.ApiKey, apiKeyRepo: repos.ApiKey,
redeemCodeRepo: repos.RedeemCode, redeemCodeRepo: repos.RedeemCode,
usageLogRepo: repos.UsageLog, usageLogRepo: repos.UsageLog,
userSubRepo: repos.UserSubscription, userSubRepo: repos.UserSubscription,
} billingCacheService: billingCacheService,
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
// 注意AdminService是接口需要类型断言
func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) {
if impl, ok := adminService.(*adminServiceImpl); ok {
impl.billingCacheService = billingCacheService
} }
} }

View File

@@ -16,13 +16,13 @@ import (
) )
var ( var (
ErrInvalidCredentials = errors.New("invalid email or password") ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active") ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists") ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token") ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired") ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required") ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled") ErrRegDisabled = errors.New("registration is currently disabled")
) )
// JWTClaims JWT载荷数据 // JWTClaims JWT载荷数据
@@ -44,33 +44,24 @@ type AuthService struct {
} }
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
func NewAuthService(userRepo *repository.UserRepository, cfg *config.Config) *AuthService { func NewAuthService(
userRepo *repository.UserRepository,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
) *AuthService {
return &AuthService{ return &AuthService{
userRepo: userRepo, userRepo: userRepo,
cfg: cfg, cfg: cfg,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
} }
} }
// SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证)
func (s *AuthService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
// SetEmailService 设置邮件服务(用于邮件验证)
func (s *AuthService) SetEmailService(emailService *EmailService) {
s.emailService = emailService
}
// SetTurnstileService 设置Turnstile服务用于验证码校验
func (s *AuthService) SetTurnstileService(turnstileService *TurnstileService) {
s.turnstileService = turnstileService
}
// SetEmailQueueService 设置邮件队列服务(用于异步发送邮件)
func (s *AuthService) SetEmailQueueService(emailQueueService *EmailQueueService) {
s.emailQueueService = emailQueueService
}
// Register 用户注册返回token和用户 // Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) { func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
return s.RegisterWithVerification(ctx, email, password, "") return s.RegisterWithVerification(ctx, email, password, "")

View File

@@ -20,6 +20,7 @@ import (
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/claude"
"sub2api/internal/repository" "sub2api/internal/repository"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -27,10 +28,11 @@ import (
) )
const ( const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
stickySessionPrefix = "sticky_session:" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL stickySessionPrefix = "sticky_session:"
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token stickySessionTTL = time.Hour // 粘性会话TTL
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
) )
// allowedHeaders 白名单headers参考CRS项目 // allowedHeaders 白名单headers参考CRS项目
@@ -601,13 +603,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// getBetaHeader 处理anthropic-beta header // getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20 // 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string { func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
const oauthBeta = "oauth-2025-04-20"
const claudeCodeBeta = "claude-code-20250219"
// 如果客户端传了anthropic-beta // 如果客户端传了anthropic-beta
if clientBetaHeader != "" { if clientBetaHeader != "" {
// 已包含oauth beta则直接返回 // 已包含oauth beta则直接返回
if strings.Contains(clientBetaHeader, oauthBeta) { if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
return clientBetaHeader return clientBetaHeader
} }
@@ -620,7 +619,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
// 在claude-code-20250219后面插入oauth beta // 在claude-code-20250219后面插入oauth beta
claudeCodeIdx := -1 claudeCodeIdx := -1
for i, p := range parts { for i, p := range parts {
if p == claudeCodeBeta { if p == claude.BetaClaudeCode {
claudeCodeIdx = i claudeCodeIdx = i
break break
} }
@@ -630,13 +629,13 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
// 在claude-code后面插入 // 在claude-code后面插入
newParts := make([]string, 0, len(parts)+1) newParts := make([]string, 0, len(parts)+1)
newParts = append(newParts, parts[:claudeCodeIdx+1]...) newParts = append(newParts, parts[:claudeCodeIdx+1]...)
newParts = append(newParts, oauthBeta) newParts = append(newParts, claude.BetaOAuth)
newParts = append(newParts, parts[claudeCodeIdx+1:]...) newParts = append(newParts, parts[claudeCodeIdx+1:]...)
return strings.Join(newParts, ",") return strings.Join(newParts, ",")
} }
// 没有claude-code放在第一位 // 没有claude-code放在第一位
return oauthBeta + "," + clientBetaHeader return claude.BetaOAuth + "," + clientBetaHeader
} }
// 客户端没传,根据模型生成 // 客户端没传,根据模型生成
@@ -650,10 +649,10 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
// haiku模型不需要claude-code beta // haiku模型不需要claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") { if strings.Contains(strings.ToLower(modelID), "haiku") {
return "oauth-2025-04-20,interleaved-thinking-2025-05-14" return claude.HaikuBetaHeader
} }
return "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14" return claude.DefaultBetaHeader
} }
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) { func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
@@ -1044,3 +1043,205 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil return nil
} }
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error {
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == model.AccountTypeApiKey {
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
body = s.replaceModelInBody(body, mappedModel)
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
}
}
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
return err
}
// 构建上游请求
upstreamResult, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
}
// 选择 HTTP client
httpClient := s.httpClient
if upstreamResult.Client != nil {
httpClient = upstreamResult.Client
}
// 发送请求
resp, err := httpClient.Do(upstreamResult.Request)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
}
defer resp.Body.Close()
// 处理 401 错误:刷新 token 重试(仅 OAuth
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
resp.Body.Close()
token, tokenType, err = s.forceRefreshToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
return fmt.Errorf("token refresh failed: %w", err)
}
upstreamResult, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return err
}
httpClient = s.httpClient
if upstreamResult.Client != nil {
httpClient = upstreamResult.Client
}
resp, err = httpClient.Do(upstreamResult.Request)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
return fmt.Errorf("retry request failed: %w", err)
}
defer resp.Body.Close()
}
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
// 处理错误响应
if resp.StatusCode >= 400 {
// 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 返回简化的错误响应
errMsg := "Upstream request failed"
switch resp.StatusCode {
case 429:
errMsg = "Rate limit exceeded"
case 529:
errMsg = "Service overloaded"
}
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
return fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// 透传成功响应
c.Data(resp.StatusCode, "application/json", respBody)
return nil
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == model.AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
// OAuth 账号:应用统一指纹和重写 userID
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
}
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("Authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
// 白名单透传 headers
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
}
}
// 确保必要的 headers 存在
if req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
}
// 配置代理
var customClient *http.Client
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
customClient = &http.Client{Transport: transport}
}
}
}
return &buildUpstreamRequestResult{
Request: req,
Client: customClient,
}, nil
}
// countTokensError 返回 count_tokens 错误响应
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -57,20 +57,22 @@ type RedeemService struct {
} }
// NewRedeemService 创建兑换码服务实例 // NewRedeemService 创建兑换码服务实例
func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService { func NewRedeemService(
redeemRepo *repository.RedeemCodeRepository,
userRepo *repository.UserRepository,
subscriptionService *SubscriptionService,
rdb *redis.Client,
billingCacheService *BillingCacheService,
) *RedeemService {
return &RedeemService{ return &RedeemService{
redeemRepo: redeemRepo, redeemRepo: redeemRepo,
userRepo: userRepo, userRepo: userRepo,
subscriptionService: subscriptionService, subscriptionService: subscriptionService,
rdb: rdb, rdb: rdb,
billingCacheService: billingCacheService,
} }
} }
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
}
// GenerateRandomCode 生成随机兑换码 // GenerateRandomCode 生成随机兑换码
func (s *RedeemService) GenerateRandomCode() (string, error) { func (s *RedeemService) GenerateRandomCode() (string, error) {
// 生成16字节随机数据 // 生成16字节随机数据

View File

@@ -1,12 +1,5 @@
package service package service
import (
"sub2api/internal/config"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// Services 服务集合容器 // Services 服务集合容器
type Services struct { type Services struct {
Auth *AuthService Auth *AuthService
@@ -34,106 +27,3 @@ type Services struct {
Concurrency *ConcurrencyService Concurrency *ConcurrencyService
Identity *IdentityService Identity *IdentityService
} }
// NewServices 创建所有服务实例
func NewServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *Services {
// 初始化价格服务
pricingService := NewPricingService(cfg)
if err := pricingService.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error())
}
// 初始化计费服务(依赖价格服务)
billingService := NewBillingService(cfg, pricingService)
// 初始化其他服务
authService := NewAuthService(repos.User, cfg)
userService := NewUserService(repos.User, cfg)
apiKeyService := NewApiKeyService(repos.ApiKey, repos.User, repos.Group, repos.UserSubscription, rdb, cfg)
groupService := NewGroupService(repos.Group)
accountService := NewAccountService(repos.Account, repos.Group)
proxyService := NewProxyService(repos.Proxy)
usageService := NewUsageService(repos.UsageLog, repos.User)
// 初始化订阅服务 (RedeemService 依赖)
subscriptionService := NewSubscriptionService(repos)
// 初始化兑换服务 (依赖订阅服务)
redeemService := NewRedeemService(repos.RedeemCode, repos.User, subscriptionService, rdb)
// 初始化Admin服务
adminService := NewAdminService(repos)
// 初始化OAuth服务GatewayService依赖
oauthService := NewOAuthService(repos.Proxy)
// 初始化限流服务
rateLimitService := NewRateLimitService(repos, cfg)
// 初始化计费缓存服务
billingCacheService := NewBillingCacheService(rdb, repos.User, repos.UserSubscription)
// 初始化账号使用量服务
accountUsageService := NewAccountUsageService(repos, oauthService)
// 初始化账号测试服务
accountTestService := NewAccountTestService(repos, oauthService)
// 初始化身份指纹服务
identityService := NewIdentityService(rdb)
// 初始化Gateway服务
gatewayService := NewGatewayService(repos, rdb, cfg, oauthService, billingService, rateLimitService, billingCacheService, identityService)
// 初始化设置服务
settingService := NewSettingService(repos.Setting, cfg)
emailService := NewEmailService(repos.Setting, rdb)
// 初始化邮件队列服务
emailQueueService := NewEmailQueueService(emailService, 3)
// 初始化Turnstile服务
turnstileService := NewTurnstileService(settingService)
// 设置Auth服务的依赖用于注册开关和邮件验证
authService.SetSettingService(settingService)
authService.SetEmailService(emailService)
authService.SetTurnstileService(turnstileService)
authService.SetEmailQueueService(emailQueueService)
// 初始化并发控制服务
concurrencyService := NewConcurrencyService(rdb)
// 注入计费缓存服务到需要失效缓存的服务
redeemService.SetBillingCacheService(billingCacheService)
subscriptionService.SetBillingCacheService(billingCacheService)
SetAdminServiceBillingCache(adminService, billingCacheService)
return &Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oauthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
}
}

View File

@@ -28,13 +28,11 @@ type SubscriptionService struct {
} }
// NewSubscriptionService 创建订阅服务 // NewSubscriptionService 创建订阅服务
func NewSubscriptionService(repos *repository.Repositories) *SubscriptionService { func NewSubscriptionService(repos *repository.Repositories, billingCacheService *BillingCacheService) *SubscriptionService {
return &SubscriptionService{repos: repos} return &SubscriptionService{
} repos: repos,
billingCacheService: billingCacheService,
// SetBillingCacheService 设置计费缓存服务(用于缓存失效) }
func (s *SubscriptionService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
} }
// AssignSubscriptionInput 分配订阅输入 // AssignSubscriptionInput 分配订阅输入
@@ -88,6 +86,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
// 如果用户已有同分组的订阅: // 如果用户已有同分组的订阅:
// - 未过期:从当前过期时间累加天数 // - 未过期:从当前过期时间累加天数
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅 // - 已过期:从当前时间开始计算新的过期时间,并激活订阅
//
// 如果没有订阅:创建新订阅 // 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) { func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型 // 检查分组是否存在且为订阅类型
@@ -191,15 +190,15 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass
now := time.Now() now := time.Now()
sub := &model.UserSubscription{ sub := &model.UserSubscription{
UserID: input.UserID, UserID: input.UserID,
GroupID: input.GroupID, GroupID: input.GroupID,
StartsAt: now, StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays), ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive, Status: model.SubscriptionStatusActive,
AssignedAt: now, AssignedAt: now,
Notes: input.Notes, Notes: input.Notes,
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
} }
// 只有当 AssignedBy > 0 时才设置0 表示系统分配,如兑换码) // 只有当 AssignedBy > 0 时才设置0 表示系统分配,如兑换码)
if input.AssignedBy > 0 { if input.AssignedBy > 0 {
@@ -225,17 +224,17 @@ type BulkAssignSubscriptionInput struct {
// BulkAssignResult 批量分配结果 // BulkAssignResult 批量分配结果
type BulkAssignResult struct { type BulkAssignResult struct {
SuccessCount int SuccessCount int
FailedCount int FailedCount int
Subscriptions []model.UserSubscription Subscriptions []model.UserSubscription
Errors []string Errors []string
} }
// BulkAssignSubscription 批量分配订阅 // BulkAssignSubscription 批量分配订阅
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) { func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
result := &BulkAssignResult{ result := &BulkAssignResult{
Subscriptions: make([]model.UserSubscription, 0), Subscriptions: make([]model.UserSubscription, 0),
Errors: make([]string, 0), Errors: make([]string, 0),
} }
for _, userID := range input.UserIDs { for _, userID := range input.UserIDs {
@@ -417,10 +416,10 @@ func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID in
// SubscriptionProgress 订阅进度 // SubscriptionProgress 订阅进度
type SubscriptionProgress struct { type SubscriptionProgress struct {
ID int64 `json:"id"` ID int64 `json:"id"`
GroupName string `json:"group_name"` GroupName string `json:"group_name"`
ExpiresAt time.Time `json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
ExpiresInDays int `json:"expires_in_days"` ExpiresInDays int `json:"expires_in_days"`
Daily *UsageWindowProgress `json:"daily,omitempty"` Daily *UsageWindowProgress `json:"daily,omitempty"`
Weekly *UsageWindowProgress `json:"weekly,omitempty"` Weekly *UsageWindowProgress `json:"weekly,omitempty"`
Monthly *UsageWindowProgress `json:"monthly,omitempty"` Monthly *UsageWindowProgress `json:"monthly,omitempty"`
@@ -428,13 +427,13 @@ type SubscriptionProgress struct {
// UsageWindowProgress 使用窗口进度 // UsageWindowProgress 使用窗口进度
type UsageWindowProgress struct { type UsageWindowProgress struct {
LimitUSD float64 `json:"limit_usd"` LimitUSD float64 `json:"limit_usd"`
UsedUSD float64 `json:"used_usd"` UsedUSD float64 `json:"used_usd"`
RemainingUSD float64 `json:"remaining_usd"` RemainingUSD float64 `json:"remaining_usd"`
Percentage float64 `json:"percentage"` Percentage float64 `json:"percentage"`
WindowStart time.Time `json:"window_start"` WindowStart time.Time `json:"window_start"`
ResetsAt time.Time `json:"resets_at"` ResetsAt time.Time `json:"resets_at"`
ResetsInSeconds int64 `json:"resets_in_seconds"` ResetsInSeconds int64 `json:"resets_in_seconds"`
} }
// GetSubscriptionProgress 获取订阅使用进度 // GetSubscriptionProgress 获取订阅使用进度
@@ -464,12 +463,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
limit := *group.DailyLimitUSD limit := *group.DailyLimitUSD
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour) resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
progress.Daily = &UsageWindowProgress{ progress.Daily = &UsageWindowProgress{
LimitUSD: limit, LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD, UsedUSD: sub.DailyUsageUSD,
RemainingUSD: limit - sub.DailyUsageUSD, RemainingUSD: limit - sub.DailyUsageUSD,
Percentage: (sub.DailyUsageUSD / limit) * 100, Percentage: (sub.DailyUsageUSD / limit) * 100,
WindowStart: *sub.DailyWindowStart, WindowStart: *sub.DailyWindowStart,
ResetsAt: resetsAt, ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()), ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
} }
if progress.Daily.RemainingUSD < 0 { if progress.Daily.RemainingUSD < 0 {
@@ -488,12 +487,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
limit := *group.WeeklyLimitUSD limit := *group.WeeklyLimitUSD
resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour) resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour)
progress.Weekly = &UsageWindowProgress{ progress.Weekly = &UsageWindowProgress{
LimitUSD: limit, LimitUSD: limit,
UsedUSD: sub.WeeklyUsageUSD, UsedUSD: sub.WeeklyUsageUSD,
RemainingUSD: limit - sub.WeeklyUsageUSD, RemainingUSD: limit - sub.WeeklyUsageUSD,
Percentage: (sub.WeeklyUsageUSD / limit) * 100, Percentage: (sub.WeeklyUsageUSD / limit) * 100,
WindowStart: *sub.WeeklyWindowStart, WindowStart: *sub.WeeklyWindowStart,
ResetsAt: resetsAt, ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()), ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
} }
if progress.Weekly.RemainingUSD < 0 { if progress.Weekly.RemainingUSD < 0 {
@@ -512,12 +511,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
limit := *group.MonthlyLimitUSD limit := *group.MonthlyLimitUSD
resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour) resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour)
progress.Monthly = &UsageWindowProgress{ progress.Monthly = &UsageWindowProgress{
LimitUSD: limit, LimitUSD: limit,
UsedUSD: sub.MonthlyUsageUSD, UsedUSD: sub.MonthlyUsageUSD,
RemainingUSD: limit - sub.MonthlyUsageUSD, RemainingUSD: limit - sub.MonthlyUsageUSD,
Percentage: (sub.MonthlyUsageUSD / limit) * 100, Percentage: (sub.MonthlyUsageUSD / limit) * 100,
WindowStart: *sub.MonthlyWindowStart, WindowStart: *sub.MonthlyWindowStart,
ResetsAt: resetsAt, ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()), ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
} }
if progress.Monthly.RemainingUSD < 0 { if progress.Monthly.RemainingUSD < 0 {

View File

@@ -0,0 +1,54 @@
package service
import (
"sub2api/internal/config"
"github.com/google/wire"
)
// ProvidePricingService creates and initializes PricingService
func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
svc := NewPricingService(cfg)
if err := svc.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error())
}
return svc, nil
}
// ProvideEmailQueueService creates EmailQueueService with default worker count
func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
return NewEmailQueueService(emailService, 3)
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewApiKeyService,
NewGroupService,
NewAccountService,
NewProxyService,
NewRedeemService,
NewUsageService,
ProvidePricingService,
NewBillingService,
NewBillingCacheService,
NewAdminService,
NewGatewayService,
NewOAuthService,
NewRateLimitService,
NewAccountUsageService,
NewAccountTestService,
NewSettingService,
NewEmailService,
ProvideEmailQueueService,
NewTurnstileService,
NewSubscriptionService,
NewConcurrencyService,
NewIdentityService,
// Provide the Services container struct
wire.Struct(new(Services), "*"),
)

8
backend/tools.go Normal file
View File

@@ -0,0 +1,8 @@
//go:build tools
// +build tools
package tools
import (
_ "github.com/google/wire/cmd/wire"
)

View File

@@ -96,7 +96,7 @@ services:
# PostgreSQL Database # PostgreSQL Database
# =========================================================================== # ===========================================================================
postgres: postgres:
image: postgres:15-alpine image: postgres:18-alpine
container_name: sub2api-postgres container_name: sub2api-postgres
restart: unless-stopped restart: unless-stopped
volumes: volumes:

View File

@@ -2,7 +2,7 @@
<html lang="zh-CN"> <html lang="zh-CN">
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" /> <link rel="icon" type="image/png" href="/logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Sub2API - AI API Gateway</title> <title>Sub2API - AI API Gateway</title>
</head> </head>

View File

@@ -11,6 +11,7 @@ import type {
PaginatedResponse, PaginatedResponse,
AccountUsageInfo, AccountUsageInfo,
WindowStats, WindowStats,
ClaudeModel,
} from '@/types'; } from '@/types';
/** /**
@@ -247,6 +248,16 @@ export async function setSchedulable(id: number, schedulable: boolean): Promise<
return data; return data;
} }
/**
* Get available models for an account
* @param id - Account ID
* @returns List of available models for this account
*/
export async function getAvailableModels(id: number): Promise<ClaudeModel[]> {
const { data } = await apiClient.get<ClaudeModel[]>(`/admin/accounts/${id}/models`);
return data;
}
export const accountsAPI = { export const accountsAPI = {
list, list,
getById, getById,
@@ -262,6 +273,7 @@ export const accountsAPI = {
getTodayStats, getTodayStats,
clearRateLimit, clearRateLimit,
setSchedulable, setSchedulable,
getAvailableModels,
generateAuthUrl, generateAuthUrl,
exchangeCode, exchangeCode,
batchCreate, batchCreate,

View File

@@ -36,6 +36,23 @@
</span> </span>
</div> </div>
<!-- Model Selection -->
<div class="space-y-1.5">
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.accounts.selectTestModel') }}
</label>
<select
v-model="selectedModelId"
:disabled="loadingModels || status === 'connecting'"
class="w-full px-3 py-2 text-sm rounded-lg border border-gray-300 dark:border-dark-500 bg-white dark:bg-dark-700 text-gray-900 dark:text-gray-100 focus:ring-2 focus:ring-primary-500 focus:border-primary-500 disabled:opacity-50 disabled:cursor-not-allowed"
>
<option v-if="loadingModels" value="">{{ t('common.loading') }}...</option>
<option v-for="model in availableModels" :key="model.id" :value="model.id">
{{ model.display_name }} ({{ model.id }})
</option>
</select>
</div>
<!-- Terminal Output --> <!-- Terminal Output -->
<div class="relative group"> <div class="relative group">
<div <div
@@ -125,10 +142,10 @@
</button> </button>
<button <button
@click="startTest" @click="startTest"
:disabled="status === 'connecting'" :disabled="status === 'connecting' || !selectedModelId"
:class="[ :class="[
'px-4 py-2 text-sm font-medium rounded-lg transition-all flex items-center gap-2', 'px-4 py-2 text-sm font-medium rounded-lg transition-all flex items-center gap-2',
status === 'connecting' status === 'connecting' || !selectedModelId
? 'bg-primary-400 text-white cursor-not-allowed' ? 'bg-primary-400 text-white cursor-not-allowed'
: status === 'success' : status === 'success'
? 'bg-green-500 hover:bg-green-600 text-white' ? 'bg-green-500 hover:bg-green-600 text-white'
@@ -161,7 +178,8 @@
import { ref, watch, nextTick } from 'vue' import { ref, watch, nextTick } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import Modal from '@/components/common/Modal.vue' import Modal from '@/components/common/Modal.vue'
import type { Account } from '@/types' import { adminAPI } from '@/api/admin'
import type { Account, ClaudeModel } from '@/types'
const { t } = useI18n() const { t } = useI18n()
@@ -184,17 +202,44 @@ const status = ref<'idle' | 'connecting' | 'success' | 'error'>('idle')
const outputLines = ref<OutputLine[]>([]) const outputLines = ref<OutputLine[]>([])
const streamingContent = ref('') const streamingContent = ref('')
const errorMessage = ref('') const errorMessage = ref('')
const availableModels = ref<ClaudeModel[]>([])
const selectedModelId = ref('')
const loadingModels = ref(false)
let eventSource: EventSource | null = null let eventSource: EventSource | null = null
// Reset state when modal opens // Load available models when modal opens
watch(() => props.show, (newVal) => { watch(() => props.show, async (newVal) => {
if (newVal) { if (newVal && props.account) {
resetState() resetState()
await loadAvailableModels()
} else { } else {
closeEventSource() closeEventSource()
} }
}) })
const loadAvailableModels = async () => {
if (!props.account) return
loadingModels.value = true
selectedModelId.value = '' // Reset selection before loading
try {
availableModels.value = await adminAPI.accounts.getAvailableModels(props.account.id)
// Default to first model (usually Sonnet)
if (availableModels.value.length > 0) {
// Try to select Sonnet as default, otherwise use first model
const sonnetModel = availableModels.value.find(m => m.id.includes('sonnet'))
selectedModelId.value = sonnetModel?.id || availableModels.value[0].id
}
} catch (error) {
console.error('Failed to load available models:', error)
// Fallback to empty list
availableModels.value = []
selectedModelId.value = ''
} finally {
loadingModels.value = false
}
}
const resetState = () => { const resetState = () => {
status.value = 'idle' status.value = 'idle'
outputLines.value = [] outputLines.value = []
@@ -227,7 +272,7 @@ const scrollToBottom = async () => {
} }
const startTest = async () => { const startTest = async () => {
if (!props.account) return if (!props.account || !selectedModelId.value) return
resetState() resetState()
status.value = 'connecting' status.value = 'connecting'
@@ -247,7 +292,8 @@ const startTest = async () => {
headers: { headers: {
'Authorization': `Bearer ${localStorage.getItem('auth_token')}`, 'Authorization': `Bearer ${localStorage.getItem('auth_token')}`,
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} },
body: JSON.stringify({ model_id: selectedModelId.value })
}) })
if (!response.ok) { if (!response.ok) {

View File

@@ -418,6 +418,31 @@
</div> </div>
</div> </div>
<!-- Intercept Warmup Requests (all account types) -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.interceptWarmupRequests') }}</label>
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">{{ t('admin.accounts.interceptWarmupRequestsDesc') }}</p>
</div>
<button
type="button"
@click="interceptWarmupRequests = !interceptWarmupRequests"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
interceptWarmupRequests ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
interceptWarmupRequests ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<div> <div>
<label class="input-label">{{ t('admin.accounts.proxy') }}</label> <label class="input-label">{{ t('admin.accounts.proxy') }}</label>
<ProxySelector <ProxySelector
@@ -590,6 +615,7 @@ const allowedModels = ref<string[]>([])
const customErrorCodesEnabled = ref(false) const customErrorCodesEnabled = ref(false)
const selectedErrorCodes = ref<number[]>([]) const selectedErrorCodes = ref<number[]>([])
const customErrorCodeInput = ref<number | null>(null) const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false)
// Common models for whitelist // Common models for whitelist
const commonModels = [ const commonModels = [
@@ -758,6 +784,7 @@ const resetForm = () => {
customErrorCodesEnabled.value = false customErrorCodesEnabled.value = false
selectedErrorCodes.value = [] selectedErrorCodes.value = []
customErrorCodeInput.value = null customErrorCodeInput.value = null
interceptWarmupRequests.value = false
oauth.resetState() oauth.resetState()
oauthFlowRef.value?.reset() oauthFlowRef.value?.reset()
} }
@@ -801,6 +828,11 @@ const handleSubmit = async () => {
credentials.custom_error_codes = [...selectedErrorCodes.value] credentials.custom_error_codes = [...selectedErrorCodes.value]
} }
// Add intercept warmup requests setting
if (interceptWarmupRequests.value) {
credentials.intercept_warmup_requests = true
}
form.credentials = credentials form.credentials = credentials
submitting.value = true submitting.value = true
@@ -847,11 +879,17 @@ const handleExchangeCode = async () => {
const extra = oauth.buildExtraInfo(tokenInfo) const extra = oauth.buildExtraInfo(tokenInfo)
// Merge interceptWarmupRequests into credentials
const credentials = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
}
await adminAPI.accounts.create({ await adminAPI.accounts.create({
name: form.name, name: form.name,
platform: form.platform, platform: form.platform,
type: addMethod.value, // Use addMethod as type: 'oauth' or 'setup-token' type: addMethod.value, // Use addMethod as type: 'oauth' or 'setup-token'
credentials: tokenInfo, credentials,
extra, extra,
proxy_id: form.proxy_id, proxy_id: form.proxy_id,
concurrency: form.concurrency, concurrency: form.concurrency,
@@ -901,11 +939,17 @@ const handleCookieAuth = async (sessionKey: string) => {
const extra = oauth.buildExtraInfo(tokenInfo) const extra = oauth.buildExtraInfo(tokenInfo)
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
// Merge interceptWarmupRequests into credentials
const credentials = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
}
await adminAPI.accounts.create({ await adminAPI.accounts.create({
name: accountName, name: accountName,
platform: form.platform, platform: form.platform,
type: addMethod.value, // Use addMethod as type: 'oauth' or 'setup-token' type: addMethod.value, // Use addMethod as type: 'oauth' or 'setup-token'
credentials: tokenInfo, credentials,
extra, extra,
proxy_id: form.proxy_id, proxy_id: form.proxy_id,
concurrency: form.concurrency, concurrency: form.concurrency,

View File

@@ -286,6 +286,31 @@
</div> </div>
</div> </div>
<!-- Intercept Warmup Requests (all account types) -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.interceptWarmupRequests') }}</label>
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">{{ t('admin.accounts.interceptWarmupRequestsDesc') }}</p>
</div>
<button
type="button"
@click="interceptWarmupRequests = !interceptWarmupRequests"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
interceptWarmupRequests ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
interceptWarmupRequests ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<div> <div>
<label class="input-label">{{ t('admin.accounts.proxy') }}</label> <label class="input-label">{{ t('admin.accounts.proxy') }}</label>
<ProxySelector <ProxySelector
@@ -401,6 +426,7 @@ const allowedModels = ref<string[]>([])
const customErrorCodesEnabled = ref(false) const customErrorCodesEnabled = ref(false)
const selectedErrorCodes = ref<number[]>([]) const selectedErrorCodes = ref<number[]>([])
const customErrorCodeInput = ref<number | null>(null) const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false)
// Common models for whitelist // Common models for whitelist
const commonModels = [ const commonModels = [
@@ -459,6 +485,10 @@ watch(() => props.account, (newAccount) => {
form.status = newAccount.status as 'active' | 'inactive' form.status = newAccount.status as 'active' | 'inactive'
form.group_ids = newAccount.group_ids || [] form.group_ids = newAccount.group_ids || []
// Load intercept warmup requests setting (applies to all account types)
const credentials = newAccount.credentials as Record<string, unknown> | undefined
interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true
// Initialize API Key fields for apikey type // Initialize API Key fields for apikey type
if (newAccount.type === 'apikey' && newAccount.credentials) { if (newAccount.type === 'apikey' && newAccount.credentials) {
const credentials = newAccount.credentials as Record<string, unknown> const credentials = newAccount.credentials as Record<string, unknown>
@@ -630,6 +660,23 @@ const handleSubmit = async () => {
newCredentials.custom_error_codes = [...selectedErrorCodes.value] newCredentials.custom_error_codes = [...selectedErrorCodes.value]
} }
// Add intercept warmup requests setting
if (interceptWarmupRequests.value) {
newCredentials.intercept_warmup_requests = true
}
updatePayload.credentials = newCredentials
} else {
// For oauth/setup-token types, only update intercept_warmup_requests if changed
const currentCredentials = props.account.credentials as Record<string, unknown> || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials }
if (interceptWarmupRequests.value) {
newCredentials.intercept_warmup_requests = true
} else {
delete newCredentials.intercept_warmup_requests
}
updatePayload.credentials = newCredentials updatePayload.credentials = newCredentials
} }

View File

@@ -12,7 +12,8 @@
]" ]"
:title="hasUpdate ? 'New version available' : 'Up to date'" :title="hasUpdate ? 'New version available' : 'Up to date'"
> >
<span class="font-medium">v{{ currentVersion }}</span> <span v-if="currentVersion" class="font-medium">v{{ currentVersion }}</span>
<span v-else class="font-medium w-12 h-3 bg-gray-200 dark:bg-dark-600 rounded animate-pulse"></span>
<!-- Update indicator --> <!-- Update indicator -->
<span v-if="hasUpdate" class="relative flex h-2 w-2"> <span v-if="hasUpdate" class="relative flex h-2 w-2">
<span class="animate-ping absolute inline-flex h-full w-full rounded-full bg-amber-400 opacity-75"></span> <span class="animate-ping absolute inline-flex h-full w-full rounded-full bg-amber-400 opacity-75"></span>
@@ -56,7 +57,8 @@
<!-- Version display - centered and prominent --> <!-- Version display - centered and prominent -->
<div class="text-center mb-4"> <div class="text-center mb-4">
<div class="inline-flex items-center gap-2"> <div class="inline-flex items-center gap-2">
<span class="text-2xl font-bold text-gray-900 dark:text-white">v{{ currentVersion }}</span> <span v-if="currentVersion" class="text-2xl font-bold text-gray-900 dark:text-white">v{{ currentVersion }}</span>
<span v-else class="text-2xl font-bold text-gray-400 dark:text-dark-500">--</span>
<!-- Show check mark when up to date --> <!-- Show check mark when up to date -->
<span v-if="!hasUpdate" class="flex items-center justify-center w-5 h-5 rounded-full bg-green-100 dark:bg-green-900/30"> <span v-if="!hasUpdate" class="flex items-center justify-center w-5 h-5 rounded-full bg-green-100 dark:bg-green-900/30">
<svg class="w-3 h-3 text-green-600 dark:text-green-400" fill="currentColor" viewBox="0 0 20 20"> <svg class="w-3 h-3 text-green-600 dark:text-green-400" fill="currentColor" viewBox="0 0 20 20">
@@ -233,8 +235,8 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onMounted, onBeforeUnmount } from 'vue'; import { ref, computed, onMounted, onBeforeUnmount } from 'vue';
import { useI18n } from 'vue-i18n'; import { useI18n } from 'vue-i18n';
import { useAuthStore } from '@/stores'; import { useAuthStore, useAppStore } from '@/stores';
import { checkUpdates, performUpdate, restartService, type VersionInfo, type ReleaseInfo } from '@/api/admin/system'; import { performUpdate, restartService } from '@/api/admin/system';
const { t } = useI18n(); const { t } = useI18n();
@@ -243,20 +245,22 @@ const props = defineProps<{
}>(); }>();
const authStore = useAuthStore(); const authStore = useAuthStore();
const appStore = useAppStore();
const isAdmin = computed(() => authStore.isAdmin); const isAdmin = computed(() => authStore.isAdmin);
const loading = ref(false);
const dropdownOpen = ref(false); const dropdownOpen = ref(false);
const dropdownRef = ref<HTMLElement | null>(null); const dropdownRef = ref<HTMLElement | null>(null);
const currentVersion = ref('0.1.0'); // Use store's cached version state
const latestVersion = ref('0.1.0'); const loading = computed(() => appStore.versionLoading);
const hasUpdate = ref(false); const currentVersion = computed(() => appStore.currentVersion || props.version || '');
const releaseInfo = ref<ReleaseInfo | null>(null); const latestVersion = computed(() => appStore.latestVersion);
const buildType = ref('source'); // "source" or "release" const hasUpdate = computed(() => appStore.hasUpdate);
const releaseInfo = computed(() => appStore.releaseInfo);
const buildType = computed(() => appStore.buildType);
// Update process states // Update process states (local to this component)
const updating = ref(false); const updating = ref(false);
const restarting = ref(false); const restarting = ref(false);
const needRestart = ref(false); const needRestart = ref(false);
@@ -277,24 +281,12 @@ function closeDropdown() {
async function refreshVersion(force = true) { async function refreshVersion(force = true) {
if (!isAdmin.value) return; if (!isAdmin.value) return;
loading.value = true; // Reset update states when refreshing
try { updateError.value = '';
const data: VersionInfo = await checkUpdates(force); updateSuccess.value = false;
currentVersion.value = data.current_version; needRestart.value = false;
latestVersion.value = data.latest_version;
buildType.value = data.build_type || 'source'; await appStore.fetchVersion(force);
// Show update indicator for all build types
hasUpdate.value = data.has_update;
releaseInfo.value = data.release_info || null;
// Reset update states when refreshing
updateError.value = '';
updateSuccess.value = false;
needRestart.value = false;
} catch (error) {
console.error('Failed to check updates:', error);
} finally {
loading.value = false;
}
} }
async function handleUpdate() { async function handleUpdate() {
@@ -308,7 +300,8 @@ async function handleUpdate() {
const result = await performUpdate(); const result = await performUpdate();
updateSuccess.value = true; updateSuccess.value = true;
needRestart.value = result.need_restart; needRestart.value = result.need_restart;
hasUpdate.value = false; // Clear version cache to reflect update completed
appStore.clearVersionCache();
} catch (error: unknown) { } catch (error: unknown) {
const err = error as { response?: { data?: { message?: string } }; message?: string }; const err = error as { response?: { data?: { message?: string } }; message?: string };
updateError.value = err.response?.data?.message || err.message || t('version.updateFailed'); updateError.value = err.response?.data?.message || err.message || t('version.updateFailed');
@@ -346,9 +339,8 @@ function handleClickOutside(event: MouseEvent) {
onMounted(() => { onMounted(() => {
if (isAdmin.value) { if (isAdmin.value) {
refreshVersion(false); // Use cached version if available, otherwise fetch
} else if (props.version) { appStore.fetchVersion(false);
currentVersion.value = props.version;
} }
document.addEventListener('click', handleClickOutside); document.addEventListener('click', handleClickOutside);
}); });

View File

@@ -108,7 +108,7 @@
</router-link> </router-link>
<a <a
href="https://github.com/fangyuan99/sub2api" href="https://github.com/Wei-Shaw/sub2api"
target="_blank" target="_blank"
rel="noopener noreferrer" rel="noopener noreferrer"
@click="closeDropdown" @click="closeDropdown"
@@ -207,7 +207,7 @@ const pageDescription = computed(() => {
}); });
function toggleMobileSidebar() { function toggleMobileSidebar() {
appStore.toggleSidebar(); appStore.toggleMobileSidebar();
} }
function toggleDropdown() { function toggleDropdown() {

View File

@@ -36,6 +36,7 @@
class="sidebar-link mb-1" class="sidebar-link mb-1"
:class="{ 'sidebar-link-active': isActive(item.path) }" :class="{ 'sidebar-link-active': isActive(item.path) }"
:title="sidebarCollapsed ? item.label : undefined" :title="sidebarCollapsed ? item.label : undefined"
@click="handleMenuItemClick"
> >
<component :is="item.icon" class="w-5 h-5 flex-shrink-0" /> <component :is="item.icon" class="w-5 h-5 flex-shrink-0" />
<transition name="fade"> <transition name="fade">
@@ -58,6 +59,7 @@
class="sidebar-link mb-1" class="sidebar-link mb-1"
:class="{ 'sidebar-link-active': isActive(item.path) }" :class="{ 'sidebar-link-active': isActive(item.path) }"
:title="sidebarCollapsed ? item.label : undefined" :title="sidebarCollapsed ? item.label : undefined"
@click="handleMenuItemClick"
> >
<component :is="item.icon" class="w-5 h-5 flex-shrink-0" /> <component :is="item.icon" class="w-5 h-5 flex-shrink-0" />
<transition name="fade"> <transition name="fade">
@@ -77,6 +79,7 @@
class="sidebar-link mb-1" class="sidebar-link mb-1"
:class="{ 'sidebar-link-active': isActive(item.path) }" :class="{ 'sidebar-link-active': isActive(item.path) }"
:title="sidebarCollapsed ? item.label : undefined" :title="sidebarCollapsed ? item.label : undefined"
@click="handleMenuItemClick"
> >
<component :is="item.icon" class="w-5 h-5 flex-shrink-0" /> <component :is="item.icon" class="w-5 h-5 flex-shrink-0" />
<transition name="fade"> <transition name="fade">
@@ -142,9 +145,9 @@ const appStore = useAppStore();
const authStore = useAuthStore(); const authStore = useAuthStore();
const sidebarCollapsed = computed(() => appStore.sidebarCollapsed); const sidebarCollapsed = computed(() => appStore.sidebarCollapsed);
const mobileOpen = computed(() => appStore.mobileOpen);
const isAdmin = computed(() => authStore.isAdmin); const isAdmin = computed(() => authStore.isAdmin);
const isDark = ref(document.documentElement.classList.contains('dark')); const isDark = ref(document.documentElement.classList.contains('dark'));
const mobileOpen = ref(false);
// Site settings // Site settings
const siteName = ref('Sub2API'); const siteName = ref('Sub2API');
@@ -303,7 +306,15 @@ function toggleTheme() {
} }
function closeMobile() { function closeMobile() {
mobileOpen.value = false; appStore.setMobileOpen(false);
}
function handleMenuItemClick() {
if (mobileOpen.value) {
setTimeout(() => {
appStore.setMobileOpen(false);
}, 150);
}
} }
function isActive(path: string): boolean { function isActive(path: string): boolean {

View File

@@ -266,6 +266,8 @@ export default {
sync: 'Sync', sync: 'Sync',
in: 'In', in: 'In',
out: 'Out', out: 'Out',
cacheRead: 'Read',
cacheWrite: 'Write',
rate: 'Rate', rate: 'Rate',
original: 'Original', original: 'Original',
billed: 'Billed', billed: 'Billed',
@@ -543,6 +545,7 @@ export default {
title: 'Subscription Settings', title: 'Subscription Settings',
type: 'Billing Type', type: 'Billing Type',
typeHint: 'Standard billing deducts from user balance. Subscription mode uses quota limits instead.', typeHint: 'Standard billing deducts from user balance. Subscription mode uses quota limits instead.',
typeNotEditable: 'Billing type cannot be changed after group creation.',
standard: 'Standard (Balance)', standard: 'Standard (Balance)',
subscription: 'Subscription (Quota)', subscription: 'Subscription (Quota)',
dailyLimit: 'Daily Limit (USD)', dailyLimit: 'Daily Limit (USD)',
@@ -695,6 +698,8 @@ export default {
enterErrorCode: 'Enter error code (100-599)', enterErrorCode: 'Enter error code (100-599)',
invalidErrorCode: 'Please enter a valid HTTP error code (100-599)', invalidErrorCode: 'Please enter a valid HTTP error code (100-599)',
errorCodeExists: 'This error code is already selected', errorCodeExists: 'This error code is already selected',
interceptWarmupRequests: 'Intercept Warmup Requests',
interceptWarmupRequestsDesc: 'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens',
proxy: 'Proxy', proxy: 'Proxy',
noProxy: 'No Proxy', noProxy: 'No Proxy',
concurrency: 'Concurrency', concurrency: 'Concurrency',
@@ -776,6 +781,7 @@ export default {
copyOutput: 'Copy output', copyOutput: 'Copy output',
startingTestForAccount: 'Starting test for account: {name}', startingTestForAccount: 'Starting test for account: {name}',
testAccountTypeLabel: 'Account type: {type}', testAccountTypeLabel: 'Account type: {type}',
selectTestModel: 'Select Test Model',
testModel: 'claude-sonnet-4-5-20250929', testModel: 'claude-sonnet-4-5-20250929',
testPrompt: 'Prompt: "hi"', testPrompt: 'Prompt: "hi"',
}, },
@@ -816,8 +822,8 @@ export default {
standardAdd: 'Standard Add', standardAdd: 'Standard Add',
batchAdd: 'Quick Add', batchAdd: 'Quick Add',
batchInput: 'Proxy List', batchInput: 'Proxy List',
batchInputPlaceholder: 'Enter one proxy per line in the following formats:\nsocks5://user:pass@192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass@proxy.example.com:443', batchInputPlaceholder: "Enter one proxy per line in the following formats:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
batchInputHint: 'Supports http, https, socks5 protocols. Format: protocol://[user:pass@]host:port', batchInputHint: "Supports http, https, socks5 protocols. Format: protocol://[user:pass{'@'}]host:port",
parsedCount: '{count} valid', parsedCount: '{count} valid',
invalidCount: '{count} invalid', invalidCount: '{count} invalid',
duplicateCount: '{count} duplicate', duplicateCount: '{count} duplicate',

View File

@@ -266,6 +266,8 @@ export default {
sync: '同步', sync: '同步',
in: '输入', in: '输入',
out: '输出', out: '输出',
cacheRead: '读取',
cacheWrite: '写入',
rate: '倍率', rate: '倍率',
original: '原始', original: '原始',
billed: '计费', billed: '计费',
@@ -598,6 +600,7 @@ export default {
title: '订阅设置', title: '订阅设置',
type: '计费类型', type: '计费类型',
typeHint: '标准计费从用户余额扣除。订阅模式使用配额限制。', typeHint: '标准计费从用户余额扣除。订阅模式使用配额限制。',
typeNotEditable: '分组创建后无法修改计费类型。',
standard: '标准(余额)', standard: '标准(余额)',
subscription: '订阅(配额)', subscription: '订阅(配额)',
dailyLimit: '每日限额USD', dailyLimit: '每日限额USD',
@@ -785,6 +788,8 @@ export default {
enterErrorCode: '输入错误码 (100-599)', enterErrorCode: '输入错误码 (100-599)',
invalidErrorCode: '请输入有效的 HTTP 错误码 (100-599)', invalidErrorCode: '请输入有效的 HTTP 错误码 (100-599)',
errorCodeExists: '该错误码已被选中', errorCodeExists: '该错误码已被选中',
interceptWarmupRequests: '拦截预热请求',
interceptWarmupRequestsDesc: '启用后,标题生成等预热请求将返回 mock 响应,不消耗上游 token',
proxy: '代理', proxy: '代理',
noProxy: '无代理', noProxy: '无代理',
concurrency: '并发数', concurrency: '并发数',
@@ -864,6 +869,7 @@ export default {
copyOutput: '复制输出', copyOutput: '复制输出',
startingTestForAccount: '开始测试账号:{name}', startingTestForAccount: '开始测试账号:{name}',
testAccountTypeLabel: '账号类型:{type}', testAccountTypeLabel: '账号类型:{type}',
selectTestModel: '选择测试模型',
testModel: 'claude-sonnet-4-5-20250929', testModel: 'claude-sonnet-4-5-20250929',
testPrompt: '提示词:"hi"', testPrompt: '提示词:"hi"',
}, },
@@ -941,8 +947,8 @@ export default {
standardAdd: '标准添加', standardAdd: '标准添加',
batchAdd: '快捷添加', batchAdd: '快捷添加',
batchInput: '代理列表', batchInput: '代理列表',
batchInputPlaceholder: '每行输入一个代理,支持以下格式:\nsocks5://user:pass@192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass@proxy.example.com:443', batchInputPlaceholder: "每行输入一个代理,支持以下格式:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
batchInputHint: '支持 http、https、socks5 协议,格式:协议://[用户名:密码@]主机:端口', batchInputHint: "支持 http、https、socks5 协议,格式:协议://[用户名:密码{'@'}]主机:端口",
parsedCount: '有效 {count} 个', parsedCount: '有效 {count} 个',
invalidCount: '无效 {count} 个', invalidCount: '无效 {count} 个',
duplicateCount: '重复 {count} 个', duplicateCount: '重复 {count} 个',

View File

@@ -305,9 +305,10 @@ router.beforeEach((to, _from, next) => {
// If route doesn't require auth, allow access // If route doesn't require auth, allow access
if (!requiresAuth) { if (!requiresAuth) {
// If already authenticated and trying to access login/register, redirect to dashboard // If already authenticated and trying to access login/register, redirect to appropriate dashboard
if (authStore.isAuthenticated && (to.path === '/login' || to.path === '/register')) { if (authStore.isAuthenticated && (to.path === '/login' || to.path === '/register')) {
next('/dashboard'); // Admin users go to admin dashboard, regular users go to user dashboard
next(authStore.isAdmin ? '/admin/dashboard' : '/dashboard');
return; return;
} }
next(); next();

View File

@@ -6,14 +6,25 @@
import { defineStore } from 'pinia'; import { defineStore } from 'pinia';
import { ref, computed } from 'vue'; import { ref, computed } from 'vue';
import type { Toast, ToastType } from '@/types'; import type { Toast, ToastType } from '@/types';
import { checkUpdates as checkUpdatesAPI, type VersionInfo, type ReleaseInfo } from '@/api/admin/system';
export const useAppStore = defineStore('app', () => { export const useAppStore = defineStore('app', () => {
// ==================== State ==================== // ==================== State ====================
const sidebarCollapsed = ref<boolean>(false); const sidebarCollapsed = ref<boolean>(false);
const mobileOpen = ref<boolean>(false);
const loading = ref<boolean>(false); const loading = ref<boolean>(false);
const toasts = ref<Toast[]>([]); const toasts = ref<Toast[]>([]);
// Version cache state
const versionLoaded = ref<boolean>(false);
const versionLoading = ref<boolean>(false);
const currentVersion = ref<string>('');
const latestVersion = ref<string>('');
const hasUpdate = ref<boolean>(false);
const buildType = ref<string>('source');
const releaseInfo = ref<ReleaseInfo | null>(null);
// Auto-incrementing ID for toasts // Auto-incrementing ID for toasts
let toastIdCounter = 0; let toastIdCounter = 0;
@@ -40,6 +51,21 @@ export const useAppStore = defineStore('app', () => {
sidebarCollapsed.value = collapsed; sidebarCollapsed.value = collapsed;
} }
/**
* Toggle mobile sidebar open state
*/
function toggleMobileSidebar(): void {
mobileOpen.value = !mobileOpen.value;
}
/**
* Set mobile sidebar open state explicitly
* @param open - Whether mobile sidebar should be open
*/
function setMobileOpen(open: boolean): void {
mobileOpen.value = open;
}
/** /**
* Set global loading state * Set global loading state
* @param isLoading - Whether app is in loading state * @param isLoading - Whether app is in loading state
@@ -192,20 +218,82 @@ export const useAppStore = defineStore('app', () => {
toasts.value = []; toasts.value = [];
} }
// ==================== Version Management ====================
/**
* Fetch version info (uses cache unless force=true)
* @param force - Force refresh from API
*/
async function fetchVersion(force = false): Promise<VersionInfo | null> {
// Return cached data if available and not forcing refresh
if (versionLoaded.value && !force) {
return {
current_version: currentVersion.value,
latest_version: latestVersion.value,
has_update: hasUpdate.value,
build_type: buildType.value,
release_info: releaseInfo.value || undefined,
cached: true,
};
}
// Prevent duplicate requests
if (versionLoading.value) {
return null;
}
versionLoading.value = true;
try {
const data = await checkUpdatesAPI(force);
currentVersion.value = data.current_version;
latestVersion.value = data.latest_version;
hasUpdate.value = data.has_update;
buildType.value = data.build_type || 'source';
releaseInfo.value = data.release_info || null;
versionLoaded.value = true;
return data;
} catch (error) {
console.error('Failed to fetch version:', error);
return null;
} finally {
versionLoading.value = false;
}
}
/**
* Clear version cache (e.g., after update)
*/
function clearVersionCache(): void {
versionLoaded.value = false;
hasUpdate.value = false;
}
// ==================== Return Store API ==================== // ==================== Return Store API ====================
return { return {
// State // State
sidebarCollapsed, sidebarCollapsed,
mobileOpen,
loading, loading,
toasts, toasts,
// Version state
versionLoaded,
versionLoading,
currentVersion,
latestVersion,
hasUpdate,
buildType,
releaseInfo,
// Computed // Computed
hasActiveToasts, hasActiveToasts,
// Actions // Actions
toggleSidebar, toggleSidebar,
setSidebarCollapsed, setSidebarCollapsed,
toggleMobileSidebar,
setMobileOpen,
setLoading, setLoading,
showToast, showToast,
showSuccess, showSuccess,
@@ -217,5 +305,9 @@ export const useAppStore = defineStore('app', () => {
withLoading, withLoading,
withLoadingAndError, withLoadingAndError,
reset, reset,
// Version actions
fetchVersion,
clearVersionCache,
}; };
}); });

View File

@@ -285,6 +285,14 @@ export type AccountType = 'oauth' | 'setup-token' | 'apikey';
export type OAuthAddMethod = 'oauth' | 'setup-token'; export type OAuthAddMethod = 'oauth' | 'setup-token';
export type ProxyProtocol = 'http' | 'https' | 'socks5'; export type ProxyProtocol = 'http' | 'https' | 'socks5';
// Claude Model type (returned by /v1/models and account models API)
export interface ClaudeModel {
id: string;
type: string;
display_name: string;
created_at: string;
}
export interface Proxy { export interface Proxy {
id: number; id: number;
name: string; name: string;

View File

@@ -282,7 +282,7 @@ const siteSubtitle = ref('AI API Gateway Platform');
const isDark = ref(document.documentElement.classList.contains('dark')); const isDark = ref(document.documentElement.classList.contains('dark'));
// GitHub URL // GitHub URL
const githubUrl = 'https://github.com/fangyuan99/sub2api'; const githubUrl = 'https://github.com/Wei-Shaw/sub2api';
// Auth state // Auth state
const isAuthenticated = computed(() => authStore.isAuthenticated); const isAuthenticated = computed(() => authStore.isAuthenticated);

View File

@@ -406,7 +406,7 @@ const lineOptions = computed(() => ({
// Model chart data // Model chart data
const modelChartData = computed(() => { const modelChartData = computed(() => {
if (!modelStats.value.length) return null if (!modelStats.value?.length) return null
const colors = [ const colors = [
'#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6',
@@ -425,7 +425,7 @@ const modelChartData = computed(() => {
// Trend chart data // Trend chart data
const trendChartData = computed(() => { const trendChartData = computed(() => {
if (!trendData.value.length) return null if (!trendData.value?.length) return null
return { return {
labels: trendData.value.map(d => d.date), labels: trendData.value.map(d => d.date),
@@ -460,7 +460,7 @@ const trendChartData = computed(() => {
// User trend chart data // User trend chart data
const userTrendChartData = computed(() => { const userTrendChartData = computed(() => {
if (!userTrend.value.length) return null if (!userTrend.value?.length) return null
// Group by user // Group by user
const userGroups = new Map<string, { name: string; data: Map<string, number> }>() const userGroups = new Map<string, { name: string; data: Map<string, number> }>()

View File

@@ -180,7 +180,7 @@
/> />
<p class="input-hint">{{ t('admin.groups.rateMultiplierHint') }}</p> <p class="input-hint">{{ t('admin.groups.rateMultiplierHint') }}</p>
</div> </div>
<div class="flex items-center gap-3"> <div v-if="createForm.subscription_type !== 'subscription'" class="flex items-center gap-3">
<button <button
type="button" type="button"
@click="createForm.is_exclusive = !createForm.is_exclusive" @click="createForm.is_exclusive = !createForm.is_exclusive"
@@ -323,7 +323,7 @@
class="input" class="input"
/> />
</div> </div>
<div class="flex items-center gap-3"> <div v-if="editForm.subscription_type !== 'subscription'" class="flex items-center gap-3">
<button <button
type="button" type="button"
@click="editForm.is_exclusive = !editForm.is_exclusive" @click="editForm.is_exclusive = !editForm.is_exclusive"
@@ -360,8 +360,9 @@
<Select <Select
v-model="editForm.subscription_type" v-model="editForm.subscription_type"
:options="subscriptionTypeOptions" :options="subscriptionTypeOptions"
:disabled="true"
/> />
<p class="input-hint">{{ t('admin.groups.subscription.typeHint') }}</p> <p class="input-hint">{{ t('admin.groups.subscription.typeNotEditable') }}</p>
</div> </div>
<!-- Subscription limits (only show when subscription type is selected) --> <!-- Subscription limits (only show when subscription type is selected) -->
@@ -676,16 +677,11 @@ const confirmDelete = async () => {
} }
} }
// 监听 subscription_type 变化,配额模式时重置 rate_multiplier 为 1 // 监听 subscription_type 变化,订阅模式时重置 rate_multiplier 为 1is_exclusive 为 true
watch(() => createForm.subscription_type, (newVal) => { watch(() => createForm.subscription_type, (newVal) => {
if (newVal === 'subscription') { if (newVal === 'subscription') {
createForm.rate_multiplier = 1.0 createForm.rate_multiplier = 1.0
} createForm.is_exclusive = true
})
watch(() => editForm.subscription_type, (newVal) => {
if (newVal === 'subscription') {
editForm.rate_multiplier = 1.0
} }
}) })

View File

@@ -647,10 +647,10 @@ const parseProxyUrl = (line: string): {
return { return {
protocol: protocol.toLowerCase() as ProxyProtocol, protocol: protocol.toLowerCase() as ProxyProtocol,
host, host: host.trim(),
port: portNum, port: portNum,
username: username || '', username: username?.trim() || '',
password: password || '' password: password?.trim() || ''
} }
} }
@@ -714,9 +714,12 @@ const handleCreateProxy = async () => {
submitting.value = true submitting.value = true
try { try {
await adminAPI.proxies.create({ await adminAPI.proxies.create({
...createForm, name: createForm.name.trim(),
username: createForm.username || null, protocol: createForm.protocol,
password: createForm.password || null host: createForm.host.trim(),
port: createForm.port,
username: createForm.username.trim() || null,
password: createForm.password.trim() || null
}) })
appStore.showSuccess(t('admin.proxies.proxyCreated')) appStore.showSuccess(t('admin.proxies.proxyCreated'))
closeCreateModal() closeCreateModal()
@@ -752,17 +755,18 @@ const handleUpdateProxy = async () => {
submitting.value = true submitting.value = true
try { try {
const updateData: any = { const updateData: any = {
name: editForm.name, name: editForm.name.trim(),
protocol: editForm.protocol, protocol: editForm.protocol,
host: editForm.host, host: editForm.host.trim(),
port: editForm.port, port: editForm.port,
username: editForm.username || null, username: editForm.username.trim() || null,
status: editForm.status status: editForm.status
} }
// Only include password if it was changed // Only include password if it was changed
if (editForm.password) { const trimmedPassword = editForm.password.trim()
updateData.password = editForm.password if (trimmedPassword) {
updateData.password = trimmedPassword
} }
await adminAPI.proxies.update(editingProxy.value.id, updateData) await adminAPI.proxies.update(editingProxy.value.id, updateData)

View File

@@ -43,13 +43,12 @@
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" /> <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg> </svg>
</div> </div>
<div> <div class="min-w-0 flex-1">
<p class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('usage.totalCost') }}</p> <p class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('usage.totalCost') }}</p>
<div class="flex items-baseline gap-2"> <p class="text-xl font-bold text-green-600 dark:text-green-400">${{ (usageStats?.total_actual_cost || 0).toFixed(4) }}</p>
<p class="text-xl font-bold text-green-600 dark:text-green-400">${{ (usageStats?.total_actual_cost || 0).toFixed(4) }}</p> <p class="text-xs text-gray-500 dark:text-gray-400">
<span class="text-xs text-gray-400 dark:text-gray-500 line-through">${{ (usageStats?.total_cost || 0).toFixed(4) }}</span> {{ t('usage.actualCost') }} / <span class="line-through">${{ (usageStats?.total_cost || 0).toFixed(4) }}</span> {{ t('usage.standardCost') }}
</div> </p>
<p class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.actualCost') }} / {{ t('usage.standardCost') }}</p>
</div> </div>
</div> </div>
</div> </div>
@@ -195,17 +194,40 @@
</template> </template>
<template #cell-tokens="{ row }"> <template #cell-tokens="{ row }">
<div class="text-sm"> <div class="text-sm space-y-1.5">
<div class="flex items-center gap-1"> <!-- Input / Output Tokens -->
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.in') }}</span> <div class="flex items-center gap-2">
<span class="font-medium text-gray-900 dark:text-white">{{ row.input_tokens.toLocaleString() }}</span> <!-- Input -->
<span class="text-gray-400 dark:text-gray-500">/</span> <div class="inline-flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.out') }}</span> <svg class="w-3.5 h-3.5 text-emerald-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<span class="font-medium text-gray-900 dark:text-white">{{ row.output_tokens.toLocaleString() }}</span> <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
</svg>
<span class="font-medium text-gray-900 dark:text-white">{{ row.input_tokens.toLocaleString() }}</span>
</div>
<!-- Output -->
<div class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-violet-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 10l7-7m0 0l7 7m-7-7v18" />
</svg>
<span class="font-medium text-gray-900 dark:text-white">{{ row.output_tokens.toLocaleString() }}</span>
</div>
</div> </div>
<div v-if="row.cache_read_tokens > 0" class="flex items-center gap-1 text-blue-600 dark:text-blue-400"> <!-- Cache Tokens (Read + Write) -->
<span>{{ t('dashboard.cache') }}</span> <div v-if="row.cache_read_tokens > 0 || row.cache_creation_tokens > 0" class="flex items-center gap-2">
<span class="font-medium">{{ row.cache_read_tokens.toLocaleString() }}</span> <!-- Cache Read -->
<div v-if="row.cache_read_tokens > 0" class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-sky-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 8h14M5 8a2 2 0 110-4h14a2 2 0 110 4M5 8v10a2 2 0 002 2h10a2 2 0 002-2V8m-9 4h4" />
</svg>
<span class="text-sky-600 dark:text-sky-400 font-medium">{{ formatCacheTokens(row.cache_read_tokens) }}</span>
</div>
<!-- Cache Write -->
<div v-if="row.cache_creation_tokens > 0" class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-amber-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
</svg>
<span class="text-amber-600 dark:text-amber-400 font-medium">{{ formatCacheTokens(row.cache_creation_tokens) }}</span>
</div>
</div> </div>
</div> </div>
</template> </template>
@@ -458,6 +480,16 @@ const formatTokens = (value: number): string => {
return value.toLocaleString() return value.toLocaleString()
} }
// Compact format for cache tokens in table cells
const formatCacheTokens = (value: number): string => {
if (value >= 1_000_000) {
return `${(value / 1_000_000).toFixed(1)}M`
} else if (value >= 1_000) {
return `${(value / 1_000).toFixed(1)}K`
}
return value.toLocaleString()
}
const formatDateTime = (dateString: string): string => { const formatDateTime = (dateString: string): string => {
const date = new Date(dateString) const date = new Date(dateString)
return date.toLocaleString('en-US', { return date.toLocaleString('en-US', {
@@ -538,7 +570,7 @@ const exportToCSV = () => {
return return
} }
const headers = ['User', 'API Key', 'Model', 'Type', 'Input Tokens', 'Output Tokens', 'Cache Tokens', 'Total Cost', 'Billing Type', 'Duration (ms)', 'Time'] const headers = ['User', 'API Key', 'Model', 'Type', 'Input Tokens', 'Output Tokens', 'Cache Read Tokens', 'Cache Write Tokens', 'Total Cost', 'Billing Type', 'Duration (ms)', 'Time']
const rows = usageLogs.value.map(log => [ const rows = usageLogs.value.map(log => [
log.user?.email || '', log.user?.email || '',
log.api_key?.name || '', log.api_key?.name || '',
@@ -547,6 +579,7 @@ const exportToCSV = () => {
log.input_tokens, log.input_tokens,
log.output_tokens, log.output_tokens,
log.cache_read_tokens, log.cache_read_tokens,
log.cache_creation_tokens,
log.total_cost.toFixed(6), log.total_cost.toFixed(6),
log.billing_type === 1 ? 'Subscription' : 'Balance', log.billing_type === 1 ? 'Subscription' : 'Balance',
log.duration_ms, log.duration_ms,

View File

@@ -531,7 +531,7 @@ const lineOptions = computed(() => ({
// Model chart data // Model chart data
const modelChartData = computed(() => { const modelChartData = computed(() => {
if (!modelStats.value.length) return null if (!modelStats.value?.length) return null
const colors = [ const colors = [
'#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6',
@@ -550,7 +550,7 @@ const modelChartData = computed(() => {
// Trend chart data // Trend chart data
const trendChartData = computed(() => { const trendChartData = computed(() => {
if (!trendData.value.length) return null if (!trendData.value?.length) return null
return { return {
labels: trendData.value.map(d => d.date), labels: trendData.value.map(d => d.date),
@@ -688,8 +688,9 @@ const loadChartData = async () => {
usageAPI.getDashboardModels({ start_date: startDate.value, end_date: endDate.value }), usageAPI.getDashboardModels({ start_date: startDate.value, end_date: endDate.value }),
]) ])
trendData.value = trendResponse.trend // Ensure we always have arrays, even if API returns null
modelStats.value = modelResponse.models trendData.value = trendResponse.trend || []
modelStats.value = modelResponse.models || []
} catch (error) { } catch (error) {
console.error('Error loading chart data:', error) console.error('Error loading chart data:', error)
} }

View File

@@ -43,13 +43,12 @@
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" /> <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg> </svg>
</div> </div>
<div> <div class="min-w-0 flex-1">
<p class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('usage.totalCost') }}</p> <p class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('usage.totalCost') }}</p>
<div class="flex items-baseline gap-2"> <p class="text-xl font-bold text-green-600 dark:text-green-400">${{ (usageStats?.total_actual_cost || 0).toFixed(4) }}</p>
<p class="text-xl font-bold text-green-600 dark:text-green-400">${{ (usageStats?.total_actual_cost || 0).toFixed(4) }}</p> <p class="text-xs text-gray-500 dark:text-gray-400">
<span class="text-xs text-gray-400 dark:text-gray-500 line-through">${{ (usageStats?.total_cost || 0).toFixed(4) }}</span> {{ t('usage.actualCost') }} / <span class="line-through">${{ (usageStats?.total_cost || 0).toFixed(4) }}</span> {{ t('usage.standardCost') }}
</div> </p>
<p class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.actualCost') }} / {{ t('usage.standardCost') }}</p>
</div> </div>
</div> </div>
</div> </div>
@@ -138,17 +137,40 @@
</template> </template>
<template #cell-tokens="{ row }"> <template #cell-tokens="{ row }">
<div class="text-sm"> <div class="text-sm space-y-1.5">
<div class="flex items-center gap-1"> <!-- Input / Output Tokens -->
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.in') }}</span> <div class="flex items-center gap-2">
<span class="font-medium text-gray-900 dark:text-white">{{ row.input_tokens.toLocaleString() }}</span> <!-- Input -->
<span class="text-gray-400 dark:text-gray-500">/</span> <div class="inline-flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.out') }}</span> <svg class="w-3.5 h-3.5 text-emerald-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<span class="font-medium text-gray-900 dark:text-white">{{ row.output_tokens.toLocaleString() }}</span> <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
</svg>
<span class="font-medium text-gray-900 dark:text-white">{{ row.input_tokens.toLocaleString() }}</span>
</div>
<!-- Output -->
<div class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-violet-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 10l7-7m0 0l7 7m-7-7v18" />
</svg>
<span class="font-medium text-gray-900 dark:text-white">{{ row.output_tokens.toLocaleString() }}</span>
</div>
</div> </div>
<div v-if="row.cache_read_tokens > 0" class="flex items-center gap-1 text-blue-600 dark:text-blue-400"> <!-- Cache Tokens (Read + Write) -->
<span>{{ t('dashboard.cache') }}</span> <div v-if="row.cache_read_tokens > 0 || row.cache_creation_tokens > 0" class="flex items-center gap-2">
<span class="font-medium">{{ row.cache_read_tokens.toLocaleString() }}</span> <!-- Cache Read -->
<div v-if="row.cache_read_tokens > 0" class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-sky-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 8h14M5 8a2 2 0 110-4h14a2 2 0 110 4M5 8v10a2 2 0 002 2h10a2 2 0 002-2V8m-9 4h4" />
</svg>
<span class="text-sky-600 dark:text-sky-400 font-medium">{{ formatCacheTokens(row.cache_read_tokens) }}</span>
</div>
<!-- Cache Write -->
<div v-if="row.cache_creation_tokens > 0" class="inline-flex items-center gap-1">
<svg class="w-3.5 h-3.5 text-amber-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
</svg>
<span class="text-amber-600 dark:text-amber-400 font-medium">{{ formatCacheTokens(row.cache_creation_tokens) }}</span>
</div>
</div> </div>
</div> </div>
</template> </template>
@@ -332,6 +354,16 @@ const formatTokens = (value: number): string => {
return value.toLocaleString() return value.toLocaleString()
} }
// Compact format for cache tokens in table cells
const formatCacheTokens = (value: number): string => {
if (value >= 1_000_000) {
return `${(value / 1_000_000).toFixed(1)}M`
} else if (value >= 1_000) {
return `${(value / 1_000).toFixed(1)}K`
}
return value.toLocaleString()
}
const formatDateTime = (dateString: string): string => { const formatDateTime = (dateString: string): string => {
const date = new Date(dateString) const date = new Date(dateString)
return date.toLocaleString('en-US', { return date.toLocaleString('en-US', {
@@ -416,13 +448,14 @@ const exportToCSV = () => {
return return
} }
const headers = ['Model', 'Type', 'Input Tokens', 'Output Tokens', 'Cache Tokens', 'Total Cost', 'Billing Type', 'First Token (ms)', 'Duration (ms)', 'Time'] const headers = ['Model', 'Type', 'Input Tokens', 'Output Tokens', 'Cache Read Tokens', 'Cache Write Tokens', 'Total Cost', 'Billing Type', 'First Token (ms)', 'Duration (ms)', 'Time']
const rows = usageLogs.value.map(log => [ const rows = usageLogs.value.map(log => [
log.model, log.model,
log.stream ? 'Stream' : 'Sync', log.stream ? 'Stream' : 'Sync',
log.input_tokens, log.input_tokens,
log.output_tokens, log.output_tokens,
log.cache_read_tokens, log.cache_read_tokens,
log.cache_creation_tokens,
log.total_cost.toFixed(6), log.total_cost.toFixed(6),
log.billing_type === 1 ? 'Subscription' : 'Balance', log.billing_type === 1 ? 'Subscription' : 'Balance',
log.first_token_ms ?? '', log.first_token_ms ?? '',