Compare commits

..

38 Commits

Author SHA1 Message Date
Wesley Liddick
39fad63ccf Merge pull request #345 from whoismonay/main
mod(frontend): 管理员订阅/兑换码分组选择展示备注
2026-01-20 16:22:53 +08:00
Wesley Liddick
5602d02b1b Merge pull request #343 from mt21625457/main
fix(调度): 完善粘性会话清理与账号调度刷新 和 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
2026-01-20 16:05:53 +08:00
shaw
81989eed1c test: add promo_code_enabled to API contract test 2026-01-20 16:02:49 +08:00
shaw
192efb84a0 feat(promo-code): complete promo code feature implementation
- Add promo_code_enabled field to SystemSettings and PublicSettings DTOs
- Add promo code validation in registration flow
- Add admin settings UI for promo code configuration
- Add i18n translations for promo code feature
2026-01-20 15:56:26 +08:00
shaw
8672347f93 fix(settings): add missing promo_code_enabled field in public settings API
The field was defined in DTO but not mapped in handler response.
2026-01-20 15:49:57 +08:00
yangjianbo
5e5d4a513b feat: 移动镜像脚本位置 2026-01-20 15:11:27 +08:00
墨颜
88b6358472 build(frontend): vite 加载开发环境变量
- 使用 loadEnv 读取 VITE_DEV_PROXY_TARGET/VITE_DEV_PORT
- 注入 public settings 与 dev proxy 使用同源后端地址
2026-01-20 15:04:18 +08:00
墨颜
dd8d5e2c42 mod(frontend): 订阅分组下拉显示备注
- 订阅管理/分配订阅:下拉项展示分组备注
- 兑换码/订阅类型:下拉项展示分组备注
- 复用 GroupOptionItem/GroupBadge 保持一致体验
2026-01-20 15:02:48 +08:00
shaw
d91e2328fb test(tlsfingerprint): add multi-profile fingerprint verification test
Add TestAllProfiles to verify TLS fingerprint configurations from
config.yaml against tls.peet.ws. Tests check JA4 cipher hash (stable
part) to validate fingerprint spoofing works correctly.
2026-01-20 14:53:15 +08:00
yangjianbo
2a16735495 fix(测试): 修复 SelectAccountWithLoadAwareness 调用缺少参数
为 gateway_multiplatform_test.go 中的 SelectAccountWithLoadAwareness
调用添加缺少的第6个参数 metadataUserID,修复 CI 测试编译错误。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 14:16:46 +08:00
yangjianbo
292f25f9ca Merge branch 'main' of https://github.com/mt21625457/aicodex2api 2026-01-20 14:02:08 +08:00
yangjianbo
c92e37775a Merge branch 'dev' 2026-01-20 13:57:08 +08:00
yangjianbo
f6ed3d1456 Merge branch 'test' into dev 2026-01-20 11:59:13 +08:00
yangjianbo
84686753e8 Merge branch 'test' of https://github.com/mt21625457/aicodex2api into test 2026-01-20 11:51:44 +08:00
yangjianbo
91f01309da fix(调度): 完善粘性会话清理与账号调度刷新
- Update/BulkUpdate 按不可调度字段触发缓存刷新
- GatewayCache 支持多前缀会话键清理
- 模型路由与混合调度优化粘性会话处理
- 补充调度与缓存相关测试覆盖
2026-01-20 11:40:55 +08:00
yangjianbo
57a1fc9d33 style(仓储): 格式化账号仓储
- gofmt 修正 lint 格式提示
2026-01-20 11:30:36 +08:00
shaw
c95a864975 docs: add TLS fingerprint tool link 2026-01-20 11:30:10 +08:00
yangjianbo
7a83db6180 fix(调度): 完善粘性会话清理与账号调度刷新
- Update/BulkUpdate 按不可调度字段触发缓存刷新
- GatewayCache 支持多前缀会话键清理
- 模型路由与混合调度优化粘性会话处理
- 补充调度与缓存相关测试覆盖
2026-01-20 11:19:32 +08:00
Wesley Liddick
a8513da7ff Merge pull request #335 from geminiwen/main
feat(subscription): 支持调整订阅时长(延长/缩短)
2026-01-20 08:52:19 +08:00
Gemini Wen
53534d3956 style(admin): 统一列设置按钮位置到刷新按钮右侧
将订阅管理和账号管理页面的列设置按钮移动到刷新按钮右侧,
与用户管理页面保持一致的布局。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:18:51 +08:00
Gemini Wen
cc07a0e295 feat(subscription): 支持调整订阅时长(延长/缩短)
- 将"延长订阅"功能改为"调整订阅",支持正数延长、负数缩短
- 后端验证:调整天数范围 -36500 到 36500,缩短后剩余天数必须 > 0
- 前端同步更新界面文案和验证逻辑

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:11:30 +08:00
Wesley Liddick
e7bc62500b Merge pull request #333 from whoismonay/main
fix: 普通用户接口移除管理员敏感字段透传
2026-01-19 21:35:51 +08:00
墨颜
c8fb9ef3a5 style(dto): 修复 gofmt 格式问题
- 修复 mappers.go 中 Notes 字段的对齐格式
- 修复 types.go 中 BulkAssignResult 结构体字段的 JSON tag 对齐

修复 golangci-lint 检查中的 gofmt 格式错误
2026-01-19 21:26:30 +08:00
Wesley Liddick
eb5e6214bc Merge pull request #332 from geminiwen/main
fix: 修复手动刷新令牌后缓存未清除导致403错误的问题
2026-01-19 20:53:06 +08:00
shaw
568d6ee10e fix: 修复测试缺少新增设置字段 2026-01-19 20:52:05 +08:00
墨颜
6aef1af76e fix(redeem): 用户兑换历史不返回备注
- 用户侧 RedeemCode DTO 移除 notes 字段,避免泄露内部备注\n- 新增 AdminRedeemCode,并调整管理员兑换码接口继续返回 notes\n- 增加 /api/v1/redeem/history 契约测试,确保用户侧响应不包含 notes
2026-01-19 20:09:35 +08:00
Gemini Wen
a54852e129 fix: 补充API契约测试中缺失的hide_ccs_import_button字段
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 20:06:36 +08:00
Gemini Wen
668118def1 fix: 修复遗漏的测试文件更新和lint错误
- 更新api_contract_test.go以匹配NewAccountHandler新增的tokenCacheInvalidator参数
- 修复errcheck lint错误,显式忽略c.Error()返回值

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 19:58:09 +08:00
yangjianbo
73e6b160f8 feat(认证): 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
为共享 req 客户端增加 HTTP/2 选项与缓存隔离
OpenAI OAuth 超时提升到 120s,并按协议控制强制
新增客户端池与 OAuth 客户端单测覆盖
修复 usage cleanup 相关 errcheck/ineffassign/staticcheck 并统一格式

测试: make test
2026-01-19 19:50:57 +08:00
Gemini Wen
6fec141de6 fix: 修复手动刷新令牌后缓存未清除导致403错误的问题
手动刷新令牌后,新token保存到数据库但Redis缓存未清除,
导致下游请求仍然使用旧的失效token,上游API返回403错误。

修复方案:在AccountHandler中注入TokenCacheInvalidator,
刷新令牌成功后调用InvalidateToken清除缓存。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 19:40:43 +08:00
墨颜
31cde6c555 fix(subscriptions): 用户订阅不返回分配信息
- 用户侧 UserSubscription DTO 移除 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段\n- 新增 AdminUserSubscription,并调整管理员订阅接口与批量分配结果使用\n- 增加 /api/v1/subscriptions 契约测试,确保用户侧响应不包含上述字段
2026-01-19 19:35:13 +08:00
shaw
b1a980f344 feat: 添加隐藏CCS导入按钮的设置选项
在管理后台设置页面新增开关,允许管理员隐藏API Keys页面的"导入CCS"按钮
2026-01-19 19:25:16 +08:00
墨颜
00d9fbd220 fix(user): 普通用户接口不返回备注
- 用户侧 dto.User 移除 notes 字段,避免泄露管理员备注\n- 新增 dto.AdminUser 并调整 /admin/users 系列接口使用\n- 前端拆分 User/AdminUser,管理端用户页面使用 AdminUser\n- 更新契约测试:/api/v1/auth/me 响应不包含 notes
2026-01-19 19:23:51 +08:00
墨颜
4f4c9679bf fix(groups): 用户分组不下发内部路由信息
- 普通用户 Group DTO 移除 model_routing/account_count/account_groups,避免泄露内部路由与账号信息\n- 新增 AdminGroup DTO,并仅在管理员分组接口返回完整字段\n- 前端拆分 Group/AdminGroup,管理端页面与 API 使用 AdminGroup\n- 增加 /api/v1/groups/available 契约测试,防止回归
2026-01-19 18:58:42 +08:00
shaw
3dab71729d feat: usage接口支持TLS指纹和缓存User-Agent 2026-01-19 17:06:16 +08:00
墨颜
2f6f758670 fix(usage): 用户使用记录不下发账号计费倍率
- 将 usage log DTO 拆分为用户/管理员两类
- 用户接口不返回 account_rate_multiplier/ip_address/account
- 管理员接口保留管理员字段
- 补充契约测试防止回归
2026-01-19 17:05:42 +08:00
shaw
090c8981dd fix: 更新Claude OAuth授权配置以匹配最新规范
- 更新TokenURL和RedirectURI为platform.claude.com
- 更新scope定义,区分浏览器URL和内部API调用
- 修正state/code_verifier生成算法使用base64url编码
- 修正授权URL参数顺序并添加code=true
- 更新token交换请求头匹配官方实现
- 清理未使用的类型和函数
2026-01-19 16:40:06 +08:00
shaw
fbb572948d fix: 修复会话数量查询使用错误的超时配置 2026-01-19 11:45:04 +08:00
77 changed files with 4646 additions and 717 deletions

View File

@@ -84,7 +84,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
accountRepository := repository.NewAccountRepository(client, db)
schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
@@ -105,21 +106,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
httpUpstream := repository.NewHTTPUpstream(configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
identityCache := repository.NewIdentityCache(redisClient)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -128,7 +130,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
@@ -137,7 +138,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)

View File

@@ -45,6 +45,7 @@ type AccountHandler struct {
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
// NewAccountHandler creates a new admin account handler
@@ -60,6 +61,7 @@ func NewAccountHandler(
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
@@ -73,6 +75,7 @@ func NewAccountHandler(
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
@@ -173,6 +176,7 @@ func (h *AccountHandler) List(c *gin.Context) {
// 识别需要查询窗口费用和会话数的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
@@ -181,6 +185,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
}
}
@@ -189,9 +194,9 @@ func (h *AccountHandler) List(c *gin.Context) {
var windowCosts map[int64]float64
var activeSessions map[int64]int
// 获取活跃会话数(批量查询)
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
@@ -606,6 +611,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return
}
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
if h.tokenCacheInvalidator != nil {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
}
}
response.Success(c, dto.AccountFromService(updatedAccount))
}

View File

@@ -94,9 +94,9 @@ func (h *GroupHandler) List(c *gin.Context) {
return
}
outGroups := make([]dto.Group, 0, len(groups))
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
}
@@ -120,9 +120,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
return
}
outGroups := make([]dto.Group, 0, len(groups))
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Success(c, outGroups)
}
@@ -142,7 +142,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Create handles creating a new group
@@ -177,7 +177,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Update handles updating a group
@@ -219,7 +219,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Delete handles deleting a group

View File

@@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) {
return
}
out := make([]dto.RedeemCode, 0, len(codes))
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.RedeemCodeFromService(code))
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// Generate handles generating new redeem codes
@@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
out := make([]dto.RedeemCode, 0, len(codes))
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Success(c, out)
}
@@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
return
}
response.Success(c, dto.RedeemCodeFromService(code))
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// GetStats handles getting redeem code statistics

View File

@@ -47,6 +47,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@@ -68,6 +69,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
@@ -89,6 +91,7 @@ type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -111,13 +114,14 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -238,6 +242,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
@@ -259,6 +264,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
@@ -311,6 +317,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
@@ -332,6 +339,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
HideCcsImportButton: updatedSettings.HideCcsImportButton,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -439,6 +447,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.HomeContent != after.HomeContent {
changed = append(changed, "home_content")
}
if before.HideCcsImportButton != after.HideCcsImportButton {
changed = append(changed, "hide_ccs_import_button")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}

View File

@@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct {
Notes string `json:"notes"`
}
// ExtendSubscriptionRequest represents extend subscription request
type ExtendSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
type AdjustSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
}
// List handles listing all subscriptions with pagination and filters
@@ -83,9 +83,9 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
@@ -105,7 +105,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// GetProgress handles getting subscription usage progress
@@ -150,7 +150,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// BulkAssign handles bulk assigning subscriptions to multiple users
@@ -180,7 +180,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
response.Success(c, dto.BulkAssignResultFromService(result))
}
// Extend handles extending a subscription
// Extend handles adjusting a subscription (extend or shorten)
// POST /api/v1/admin/subscriptions/:id/extend
func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -189,7 +189,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
var req ExtendSubscriptionRequest
var req AdjustSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
@@ -201,7 +201,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// Revoke handles revoking a subscription
@@ -239,9 +239,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
@@ -261,9 +261,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.Success(c, out)
}

View File

@@ -163,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) {
return
}
out := make([]dto.UsageLog, 0, len(records))
out := make([]dto.AdminUsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
}

View File

@@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) {
return
}
out := make([]dto.User, 0, len(users))
out := make([]dto.AdminUser, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromService(&users[i]))
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Create handles creating a new user
@@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Update handles updating a user
@@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Delete handles deleting a user
@@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// GetUserAPIKeys handles getting user's API keys

View File

@@ -195,6 +195,15 @@ type ValidatePromoCodeResponse struct {
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
// 检查优惠码功能是否启用
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: "PROMO_CODE_DISABLED",
})
return
}
var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())

View File

@@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
@@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User {
return out
}
// UserFromServiceAdmin converts a service User to DTO for admin users.
// It includes notes - user-facing endpoints must not use this.
func UserFromServiceAdmin(u *service.User) *AdminUser {
if u == nil {
return nil
}
base := UserFromService(u)
if base == nil {
return nil
}
return &AdminUser{
User: *base,
Notes: u.Notes,
}
}
func APIKeyFromService(k *service.APIKey) *APIKey {
if k == nil {
return nil
@@ -72,36 +87,29 @@ func GroupFromServiceShallow(g *service.Group) *Group {
if g == nil {
return nil
}
return &Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
out := groupFromServiceBase(g)
return &out
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
out := GroupFromServiceShallow(g)
return GroupFromServiceShallow(g)
}
// GroupFromServiceAdmin converts a service Group to DTO for admin users.
// It includes internal fields like model_routing and account_count.
func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
if g == nil {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
AccountCount: g.AccountCount,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
@@ -112,6 +120,29 @@ func GroupFromService(g *service.Group) *Group {
return out
}
func groupFromServiceBase(g *service.Group) Group {
return Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}
func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
@@ -273,7 +304,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
if rc == nil {
return nil
}
return &RedeemCode{
out := redeemCodeFromServiceBase(rc)
return &out
}
// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
// It includes notes - user-facing endpoints must not use this.
func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
if rc == nil {
return nil
}
return &AdminRedeemCode{
RedeemCode: redeemCodeFromServiceBase(rc),
Notes: rc.Notes,
}
}
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
return RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
@@ -281,7 +329,6 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
Status: rc.Status,
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
Notes: rc.Notes,
CreatedAt: rc.CreatedAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
@@ -302,14 +349,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
}
}
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
if l == nil {
return nil
}
result := &UsageLog{
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
// 普通用户 DTO严禁包含管理员字段例如 account_rate_multiplier、ip_address、account
return UsageLog{
ID: l.ID,
UserID: l.UserID,
APIKeyID: l.APIKeyID,
@@ -331,7 +373,6 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
AccountRateMultiplier: l.AccountRateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
@@ -342,30 +383,33 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
Account: account,
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
// IP 地址仅对管理员可见
if includeIPAddress {
result.IPAddress = l.IPAddress
}
return result
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details and IP address - users should not see these.
func UsageLogFromService(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, nil, false)
if l == nil {
return nil
}
u := usageLogFromServiceUser(l)
return &u
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only) and IP address.
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
if l == nil {
return nil
}
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l),
AccountRateMultiplier: l.AccountRateMultiplier,
IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account),
}
}
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
@@ -414,7 +458,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
if sub == nil {
return nil
}
return &UserSubscription{
out := userSubscriptionFromServiceBase(sub)
return &out
}
// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users.
// It includes assignment metadata and notes.
func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription {
if sub == nil {
return nil
}
return &AdminUserSubscription{
UserSubscription: userSubscriptionFromServiceBase(sub),
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription {
return UserSubscription{
ID: sub.ID,
UserID: sub.UserID,
GroupID: sub.GroupID,
@@ -427,14 +491,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
DailyUsageUSD: sub.DailyUsageUSD,
WeeklyUsageUSD: sub.WeeklyUsageUSD,
MonthlyUsageUSD: sub.MonthlyUsageUSD,
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
CreatedAt: sub.CreatedAt,
UpdatedAt: sub.UpdatedAt,
User: UserFromServiceShallow(sub.User),
Group: GroupFromServiceShallow(sub.Group),
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
@@ -442,9 +502,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
if r == nil {
return nil
}
subs := make([]UserSubscription, 0, len(r.Subscriptions))
subs := make([]AdminUserSubscription, 0, len(r.Subscriptions))
for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
}
return &BulkAssignResult{
SuccessCount: r.SuccessCount,

View File

@@ -4,6 +4,7 @@ package dto
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -22,13 +23,14 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@@ -54,6 +56,7 @@ type SystemSettings struct {
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
@@ -63,6 +66,7 @@ type PublicSettings struct {
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}

View File

@@ -6,7 +6,6 @@ type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
@@ -19,6 +18,14 @@ type User struct {
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
// AdminUser 是管理员接口使用的 user DTO包含敏感/内部字段)。
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
type AdminUser struct {
User
Notes string `json:"notes"`
}
type APIKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
@@ -58,13 +65,19 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AdminGroup 是管理员接口使用的 group DTO包含敏感/内部字段)。
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
type AdminGroup struct {
Group
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
}
@@ -180,7 +193,6 @@ type RedeemCode struct {
Status string `json:"status"`
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"`
@@ -190,6 +202,15 @@ type RedeemCode struct {
Group *Group `json:"group,omitempty"`
}
// AdminRedeemCode 是管理员接口使用的 redeem code DTO包含 notes 等字段)。
// 注意:普通用户接口不得返回 notes 等内部信息。
type AdminRedeemCode struct {
RedeemCode
Notes string `json:"notes"`
}
// UsageLog 是普通用户接口使用的 usage log DTO不包含管理员字段
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
@@ -209,14 +230,13 @@ type UsageLog struct {
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
@@ -230,18 +250,28 @@ type UsageLog struct {
// User-Agent
UserAgent *string `json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
APIKey *APIKey `json:"api_key,omitempty"`
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
// AdminUsageLog 是管理员接口使用的 usage log DTO包含管理员字段
type AdminUsageLog struct {
UsageLog
// AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
// IPAddress 用户请求 IP仅管理员可见
IPAddress *string `json:"ip_address,omitempty"`
// Account 最小账号信息(避免泄露敏感字段)
Account *AccountSummary `json:"account,omitempty"`
}
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
@@ -300,23 +330,30 @@ type UserSubscription struct {
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
// AdminUserSubscription 是管理员接口使用的订阅 DTO包含分配信息/备注等字段)。
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
type AdminUserSubscription struct {
UserSubscription
AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
}
type BulkAssignResult struct {
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []AdminUserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}
// PromoCode 注册优惠码

View File

@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
@@ -43,6 +44,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
})

View File

@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
// 清空notes字段普通用户不应看到备注
userData.Notes = ""
response.Success(c, dto.UserFromService(userData))
}
@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
// 清空notes字段普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, dto.UserFromService(updatedUser))
}

View File

@@ -13,20 +13,26 @@ import (
"time"
)
// Claude OAuth Constants (from CRS project)
// Claude OAuth Constants
const (
// OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
TokenURL = "https://platform.claude.com/v1/oauth/token"
RedirectURI = "https://platform.claude.com/oauth/code/callback"
// Scopes
ScopeProfile = "user:profile"
// Scopes - Browser URL (includes org:create_api_key for user authorization)
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code"
// Scopes - Internal API call (org:create_api_key not supported in API)
ScopeAPI = "user:profile user:inference user:sessions:claude_code"
// Scopes - Setup token (inference only)
ScopeInference = "user:inference"
// Code Verifier character set (RFC 7636 compliant)
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
// Session TTL
SessionTTL = 30 * time.Minute
)
@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore {
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) {
return b, nil
}
// GenerateState generates a random state string for OAuth
// GenerateState generates a random state string for OAuth (base64url encoded)
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
return base64URLEncode(bytes), nil
}
// GenerateSessionID generates a unique session ID
@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) {
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
// GenerateCodeVerifier generates a PKCE code verifier using character set method
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
const targetLen = 32
charsetLen := len(codeVerifierCharset)
limit := 256 - (256 % charsetLen)
result := make([]byte, 0, targetLen)
randBuf := make([]byte, targetLen*2)
for len(result) < targetLen {
if _, err := rand.Read(randBuf); err != nil {
return "", err
}
for _, b := range randBuf {
if int(b) < limit {
result = append(result, codeVerifierCharset[int(b)%charsetLen])
if len(result) >= targetLen {
break
}
}
}
}
return base64URLEncode(bytes), nil
return base64URLEncode(result), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string {
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OAuth authorization URL
// BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
encodedRedirectURI := url.QueryEscape(RedirectURI)
encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s",
AuthorizeURL,
ClientID,
encodedRedirectURI,
encodedScope,
codeChallenge,
state,
)
}
// TokenResponse represents the token response from OAuth provider
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"`
}
@@ -207,31 +217,3 @@ type OrgInfo struct {
type AccountInfo struct {
UUID string `json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
}

View File

@@ -305,3 +305,139 @@ func mustParseURL(rawURL string) *url.URL {
}
return u
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
return nil
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}

View File

@@ -39,9 +39,15 @@ import (
// 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
type accountRepository struct {
client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
// 确保粘性会话能及时感知账号不可用状态。
// Used to proactively sync account snapshot to cache when status changes,
// ensuring sticky sessions can promptly detect unavailable accounts.
schedulerCache service.SchedulerCache
}
type tempUnschedSnapshot struct {
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB)
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
}
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
return &accountRepository{client: client, sql: sqlq}
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
}
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
}
return nil
}
@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
//
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
// when account status changes. Called when account is set to error, disabled,
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
// logic can promptly detect the latest account state and avoid using unavailable accounts.
func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
if r == nil || r.schedulerCache == nil || accountID <= 0 {
return
}
account, err := r.GetByID(ctx, accountID)
if err != nil {
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
return
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
}
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
if !schedulable {
r.syncSchedulerAccountSnapshot(ctx, id)
}
return nil
}
@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
shouldSync := false
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
shouldSync = true
}
if updates.Schedulable != nil && !*updates.Schedulable {
shouldSync = true
}
if shouldSync {
for _, id := range ids {
r.syncSchedulerAccountSnapshot(ctx, id)
}
}
}
return rows, nil
}

View File

@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
repo *accountRepository
}
type schedulerCacheRecorder struct {
setAccounts []*service.Account
}
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
return nil, false, nil
}
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
return nil
}
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
s.setAccounts = append(s.setAccounts, account)
return nil
}
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
return nil
}
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
return true, nil
}
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
return 0, nil
}
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
return nil
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = newAccountRepositoryWithSQL(s.client, tx)
s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestAccountRepoSuite(t *testing.T) {
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
s.Require().Equal("updated", got.Name)
}
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
account.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// 每个 case 重新获取隔离资源
tx := testEntTx(s.T())
client := tx.Client()
repo := newAccountRepositoryWithSQL(client, tx)
repo := newAccountRepositoryWithSQL(client, tx, nil)
ctx := context.Background()
tt.setup(client)
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
}
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
disabled := service.StatusDisabled
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
Status: &disabled,
})
s.Require().NoError(err)
s.Require().Equal(int64(2), rows)
s.Require().Len(cacheRecorder.setAccounts, 2)
ids := map[int64]struct{}{}
for _, acc := range cacheRecorder.setAccounts {
ids[acc.ID] = struct{}{}
}
s.Require().Contains(ids, account1.ID)
s.Require().Contains(ids, account2.ID)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---

View File

@@ -182,7 +182,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json, text/plain, */*").
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", "axios/1.8.4").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
@@ -205,8 +207,6 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
// Anthropic OAuth API 期望 JSON 格式的请求体
reqBody := map[string]any{
"grant_type": "refresh_token",
"refresh_token": refreshToken,
@@ -217,7 +217,9 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json, text/plain, */*").
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", "axios/1.8.4").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)

View File

@@ -171,7 +171,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
if tt.wantErr {
require.Error(s.T(), err)

View File

@@ -14,37 +14,82 @@ import (
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
// 默认 User-Agent与用户抓包的请求一致
const defaultUsageUserAgent = "claude-code/2.1.7"
type claudeUsageService struct {
usageURL string
allowPrivateHosts bool
httpUpstream service.HTTPUpstream
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
// NewClaudeUsageFetcher 创建 Claude 用量获取服务
// httpUpstream: 可选,如果提供则支持 TLS 指纹伪装
func NewClaudeUsageFetcher(httpUpstream service.HTTPUpstream) service.ClaudeUsageFetcher {
return &claudeUsageService{
usageURL: defaultClaudeUsageURL,
httpUpstream: httpUpstream,
}
}
// FetchUsage 简单版本,不支持 TLS 指纹(向后兼容)
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
return s.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{
AccessToken: accessToken,
ProxyURL: proxyURL,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
// FetchUsageWithOptions 完整版本,支持 TLS 指纹和自定义 User-Agent
func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *service.ClaudeUsageFetchOptions) (*service.ClaudeUsageResponse, error) {
if opts == nil {
return nil, fmt.Errorf("options is nil")
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
// 设置请求头(与抓包一致,但不设置 Accept-Encoding让 Go 自动处理压缩)
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
// 设置 User-Agent优先使用缓存的 Fingerprint否则使用默认值
userAgent := defaultUsageUserAgent
if opts.Fingerprint != nil && opts.Fingerprint.UserAgent != "" {
userAgent = opts.Fingerprint.UserAgent
}
req.Header.Set("User-Agent", userAgent)
var resp *http.Response
// 如果启用 TLS 指纹且有 HTTPUpstream使用 DoWithTLS
if opts.EnableTLSFingerprint && s.httpUpstream != nil {
// accountConcurrency 传 0 使用默认连接池配置usage 请求不需要特殊的并发设置
resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true)
if err != nil {
return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err)
}
} else {
// 不启用 TLS 指纹,使用普通 HTTP 客户端
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: opts.ProxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
resp, err = client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
}
defer func() { _ = resp.Body.Close() }()

View File

@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
// 以便下次请求能够重新选择可用账号。
//
// DeleteSessionAccountID removes the sticky session binding for the given session.
// Called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
sessionID := "openai:s4"
accountID := int64(102)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
groupID := int64(1)

View File

@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestGatewayRoutingSuite(t *testing.T) {

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -21,7 +22,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
client := createOpenAIReqClient(s.tokenURL, proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
client := createOpenAIReqClient(s.tokenURL, proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
func createOpenAIReqClient(proxyURL string) *req.Client {
func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
forceHTTP2 := false
if parsedURL, err := url.Parse(tokenURL); err == nil {
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
}
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
ForceHTTP2: forceHTTP2,
})
}

View File

@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require.ErrorContains(s.T(), err, "status 401")
}
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
client := NewOpenAIOAuthClient()
svc, ok := client.(*openaiOAuthService)
require.True(t, ok)
require.Equal(t, openai.TokenURL, svc.tokenURL)
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}

View File

@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
ForceHTTP2 bool // 是否强制使用 HTTP/2
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
client := req.C().SetTimeout(opts.Timeout)
if opts.ForceHTTP2 {
client = client.EnableForceHTTP2()
}
if opts.Impersonate {
client = client.ImpersonateChrome()
}
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t",
return fmt.Sprintf("%s|%s|%t|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
opts.ForceHTTP2,
)
}

View File

@@ -0,0 +1,102 @@
package repository
import (
"reflect"
"sync"
"testing"
"time"
"unsafe"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)
func forceHTTPVersion(t *testing.T, client *req.Client) string {
t.Helper()
transport := client.GetTransport()
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
require.True(t, field.IsValid(), "forceHttpVersion field not found")
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
}
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
sharedReqClients = sync.Map{}
base := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
clientDefault := getSharedReqClient(base)
force := base
force.ForceHTTP2 = true
clientForce := getSharedReqClient(force)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
}
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
first := getSharedReqClient(opts)
second := getSharedReqClient(opts)
require.Same(t, first, second)
}
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 3 * time.Second,
}
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
require.True(t, ok)
require.IsType(t, "invalid", loaded)
}
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 4 * time.Second,
Impersonate: true,
}
client := getSharedReqClient(opts)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, "2", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}

View File

@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
cache := NewSchedulerCache(rdb)

View File

@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
if len(accountIDs) == 0 {
return make(map[int64]int), nil
}
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
// 使用 pipeline 批量执行
pipe := c.rdb.Pipeline()
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, accountID := range accountIDs {
key := sessionLimitKey(accountID)
// 使用各账号自己的 idleTimeout如果没有则用默认值
idleTimeout := c.defaultIdleTimeout
if idleTimeouts != nil {
if t, ok := idleTimeouts[accountID]; ok && t > 0 {
idleTimeout = t
}
}
idleTimeoutSeconds := int(idleTimeout.Seconds())
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
}

View File

@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1,
"email": "alice@example.com",
"username": "alice",
"notes": "hello",
"role": "user",
"balance": 12.5,
"concurrency": 5,
@@ -131,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
}
}`,
},
{
name: "GET /api/v1/groups/available",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
// 普通用户可见的分组列表不应包含内部字段(如 model_routing/account_count
deps.groupRepo.SetActive([]service.Group{
{
ID: 10,
Name: "Group One",
Description: "desc",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.5,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-3-*": []int64{101, 102},
},
AccountCount: 2,
CreatedAt: deps.now,
UpdatedAt: deps.now,
},
})
deps.userSubRepo.SetActiveByUserID(1, nil)
},
method: http.MethodGet,
path: "/api/v1/groups/available",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": [
{
"id": 10,
"name": "Group One",
"description": "desc",
"platform": "anthropic",
"rate_multiplier": 1.5,
"is_exclusive": false,
"status": "active",
"subscription_type": "standard",
"daily_limit_usd": null,
"weekly_limit_usd": null,
"monthly_limit_usd": null,
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"claude_code_only": false,
"fallback_group_id": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`,
},
{
name: "GET /api/v1/subscriptions",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
{
ID: 501,
UserID: 1,
GroupID: 10,
StartsAt: deps.now,
ExpiresAt: deps.now.Add(24 * time.Hour),
Status: service.SubscriptionStatusActive,
DailyUsageUSD: 1.23,
WeeklyUsageUSD: 2.34,
MonthlyUsageUSD: 3.45,
AssignedBy: ptr(int64(999)),
AssignedAt: deps.now,
Notes: "admin-note",
CreatedAt: deps.now,
UpdatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/subscriptions",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": [
{
"id": 501,
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "2025-01-03T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
"monthly_window_start": null,
"daily_usage_usd": 1.23,
"weekly_usage_usd": 2.34,
"monthly_usage_usd": 3.45,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`,
},
{
name: "GET /api/v1/redeem/history",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
// 普通用户兑换历史不应包含 notes 等内部字段。
deps.redeemRepo.SetByUser(1, []service.RedeemCode{
{
ID: 900,
Code: "CODE-123",
Type: service.RedeemTypeBalance,
Value: 1.25,
Status: service.StatusUsed,
UsedBy: ptr(int64(1)),
UsedAt: ptr(deps.now),
Notes: "internal-note",
CreatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/redeem/history",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": [
{
"id": 900,
"code": "CODE-123",
"type": "balance",
"value": 1.25,
"status": "used",
"used_by": 1,
"used_at": "2025-01-02T03:04:05Z",
"created_at": "2025-01-02T03:04:05Z",
"group_id": null,
"validity_days": 0
}
]
}`,
},
{
name: "GET /api/v1/usage/stats",
setup: func(t *testing.T, deps *contractDeps) {
@@ -190,24 +336,25 @@ func TestAPIContracts(t *testing.T) {
t.Helper()
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
{
ID: 1,
UserID: 1,
APIKeyID: 100,
AccountID: 200,
RequestID: "req_123",
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 1,
CacheReadTokens: 2,
TotalCost: 0.5,
ActualCost: 0.5,
RateMultiplier: 1,
BillingType: service.BillingTypeBalance,
Stream: true,
DurationMs: ptr(100),
FirstTokenMs: ptr(50),
CreatedAt: deps.now,
ID: 1,
UserID: 1,
APIKeyID: 100,
AccountID: 200,
AccountRateMultiplier: ptr(0.5),
RequestID: "req_123",
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 1,
CacheReadTokens: 2,
TotalCost: 0.5,
ActualCost: 0.5,
RateMultiplier: 1,
BillingType: service.BillingTypeBalance,
Stream: true,
DurationMs: ptr(100),
FirstTokenMs: ptr(50),
CreatedAt: deps.now,
},
})
},
@@ -238,10 +385,9 @@ func TestAPIContracts(t *testing.T) {
"output_cost": 0,
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
@@ -266,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
@@ -304,6 +451,7 @@ func TestAPIContracts(t *testing.T) {
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
@@ -337,7 +485,8 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"home_content": ""
"home_content": "",
"hide_ccs_import_button": false
}
}`,
},
@@ -385,8 +534,11 @@ type contractDeps struct {
now time.Time
router http.Handler
apiKeyRepo *stubApiKeyRepo
groupRepo *stubGroupRepo
userSubRepo *stubUserSubscriptionRepo
usageRepo *stubUsageLogRepo
settingRepo *stubSettingRepo
redeemRepo *stubRedeemCodeRepo
}
func newContractDeps(t *testing.T) *contractDeps {
@@ -414,11 +566,11 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyRepo := newStubApiKeyRepo(now)
apiKeyCache := stubApiKeyCache{}
groupRepo := stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{}
groupRepo := &stubGroupRepo{}
userSubRepo := &stubUserSubscriptionRepo{}
accountRepo := stubAccountRepo{}
proxyRepo := stubProxyRepo{}
redeemRepo := stubRedeemCodeRepo{}
redeemRepo := &stubRedeemCodeRepo{}
cfg := &config.Config{
Default: config.DefaultConfig{
@@ -433,6 +585,12 @@ func newContractDeps(t *testing.T) *contractDeps {
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
redeemHandler := handler.NewRedeemHandler(redeemService)
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
@@ -441,7 +599,7 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -472,12 +630,21 @@ func newContractDeps(t *testing.T) *contractDeps {
v1Keys.Use(jwtAuth)
v1Keys.GET("/keys", apiKeyHandler.List)
v1Keys.POST("/keys", apiKeyHandler.Create)
v1Keys.GET("/groups/available", apiKeyHandler.GetAvailableGroups)
v1Usage := v1.Group("")
v1Usage.Use(jwtAuth)
v1Usage.GET("/usage", usageHandler.List)
v1Usage.GET("/usage/stats", usageHandler.Stats)
v1Subs := v1.Group("")
v1Subs.Use(jwtAuth)
v1Subs.GET("/subscriptions", subscriptionHandler.List)
v1Redeem := v1.Group("")
v1Redeem.Use(jwtAuth)
v1Redeem.GET("/redeem/history", redeemHandler.GetHistory)
v1Admin := v1.Group("/admin")
v1Admin.Use(adminAuth)
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
@@ -487,8 +654,11 @@ func newContractDeps(t *testing.T) *contractDeps {
now: now,
router: r,
apiKeyRepo: apiKeyRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
usageRepo: usageRepo,
settingRepo: settingRepo,
redeemRepo: redeemRepo,
}
}
@@ -626,7 +796,13 @@ func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handl
return nil
}
type stubGroupRepo struct{}
type stubGroupRepo struct {
active []service.Group
}
func (r *stubGroupRepo) SetActive(groups []service.Group) {
r.active = append([]service.Group(nil), groups...)
}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
return errors.New("not implemented")
@@ -660,12 +836,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi
return nil, nil, errors.New("not implemented")
}
func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
return nil, errors.New("not implemented")
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
return append([]service.Group(nil), r.active...), nil
}
func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, errors.New("not implemented")
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
out := make([]service.Group, 0, len(r.active))
for i := range r.active {
g := r.active[i]
if g.Platform == platform {
out = append(out, g)
}
}
return out, nil
}
func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
@@ -883,7 +1066,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID
return nil, errors.New("not implemented")
}
type stubRedeemCodeRepo struct{}
type stubRedeemCodeRepo struct {
byUser map[int64][]service.RedeemCode
}
func (r *stubRedeemCodeRepo) SetByUser(userID int64, codes []service.RedeemCode) {
if r.byUser == nil {
r.byUser = make(map[int64][]service.RedeemCode)
}
r.byUser[userID] = append([]service.RedeemCode(nil), codes...)
}
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
return errors.New("not implemented")
@@ -921,11 +1113,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination
return nil, nil, errors.New("not implemented")
}
func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
return nil, errors.New("not implemented")
func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
if r.byUser == nil {
return nil, nil
}
codes := r.byUser[userID]
if limit > 0 && len(codes) > limit {
codes = codes[:limit]
}
return append([]service.RedeemCode(nil), codes...), nil
}
type stubUserSubscriptionRepo struct{}
type stubUserSubscriptionRepo struct {
byUser map[int64][]service.UserSubscription
activeByUser map[int64][]service.UserSubscription
}
func (r *stubUserSubscriptionRepo) SetByUserID(userID int64, subs []service.UserSubscription) {
if r.byUser == nil {
r.byUser = make(map[int64][]service.UserSubscription)
}
r.byUser[userID] = append([]service.UserSubscription(nil), subs...)
}
func (r *stubUserSubscriptionRepo) SetActiveByUserID(userID int64, subs []service.UserSubscription) {
if r.activeByUser == nil {
r.activeByUser = make(map[int64][]service.UserSubscription)
}
r.activeByUser[userID] = append([]service.UserSubscription(nil), subs...)
}
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
@@ -945,11 +1161,17 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub
func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
if r.byUser == nil {
return nil, nil
}
return append([]service.UserSubscription(nil), r.byUser[userID]...), nil
}
func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
if r.activeByUser == nil {
return nil, nil
}
return append([]service.UserSubscription(nil), r.activeByUser[userID]...), nil
}
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")

View File

@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
} `json:"seven_day_sonnet"`
}
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
type ClaudeUsageFetchOptions struct {
AccessToken string // OAuth access token
ProxyURL string // 代理 URL可选
AccountID int64 // 账号 ID用于 TLS 指纹选择)
EnableTLSFingerprint bool // 是否启用 TLS 指纹伪装
Fingerprint *Fingerprint // 缓存的指纹信息User-Agent 等)
}
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
type ClaudeUsageFetcher interface {
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
}
// AccountUsageService 账号使用量查询服务
@@ -170,6 +181,7 @@ type AccountUsageService struct {
geminiQuotaService *GeminiQuotaService
antigravityQuotaFetcher *AntigravityQuotaFetcher
cache *UsageCache
identityCache IdentityCache
}
// NewAccountUsageService 创建AccountUsageService实例
@@ -180,6 +192,7 @@ func NewAccountUsageService(
geminiQuotaService *GeminiQuotaService,
antigravityQuotaFetcher *AntigravityQuotaFetcher,
cache *UsageCache,
identityCache IdentityCache,
) *AccountUsageService {
return &AccountUsageService{
accountRepo: accountRepo,
@@ -188,6 +201,7 @@ func NewAccountUsageService(
geminiQuotaService: geminiQuotaService,
antigravityQuotaFetcher: antigravityQuotaFetcher,
cache: cache,
identityCache: identityCache,
}
}
@@ -424,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
}
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
// 如果有缓存的 Fingerprint则使用缓存的 User-Agent 等信息
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
accessToken := account.GetCredential("access_token")
if accessToken == "" {
@@ -435,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
proxyURL = account.Proxy.URL()
}
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
// 构建完整的选项
opts := &ClaudeUsageFetchOptions{
AccessToken: accessToken,
ProxyURL: proxyURL,
AccountID: account.ID,
EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
}
// 尝试获取缓存的 Fingerprint包含 User-Agent 等信息)
if s.identityCache != nil {
if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
opts.Fingerprint = fp
}
}
return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
}
// parseTime 尝试多种格式解析时间

View File

@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
// 应用优惠码(如果提供)
if promoCode != "" && s.promoService != nil {
// 应用优惠码(如果提供且功能已启用
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)

View File

@@ -71,6 +71,7 @@ const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
@@ -93,13 +94,14 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML或 URL 作为 iframe src
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML或 URL 作为 iframe src
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量

File diff suppressed because it is too large Load Diff

View File

@@ -99,11 +99,24 @@ var allowedHeaders = map[string]bool{
"content-type": true,
}
// GatewayCache defines cache operations for gateway service
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话Sticky Session的存储、查询、刷新和删除功能。
//
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
// GetSessionAccountID 获取粘性会话绑定的账号 ID
// Get the account ID bound to a sticky session
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
// SetSessionAccountID 设置粘性会话与账号的绑定关系
// Set the binding between sticky session and account
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
// RefreshSessionTTL 刷新粘性会话的过期时间
// Refresh the expiration time of a sticky session
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -114,6 +127,28 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account) bool {
if account == nil {
return false
}
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
return true
}
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
return false
}
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
@@ -658,6 +693,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
} else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
}
}
@@ -764,41 +801,52 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID]
if ok && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
} else {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
@@ -1418,14 +1466,20 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
return account, nil
}
}
}
@@ -1515,11 +1569,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
return account, nil
}
}
}
@@ -1619,15 +1679,21 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
}
}
@@ -1718,12 +1784,18 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
return account, nil
}
}
}

View File

@@ -82,70 +82,23 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup
} else {
var err error
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
}
platform = group.Platform
} else {
// 无分组时只使用原生 gemini 平台
platform = PlatformGemini
// 1. 确定目标平台和调度模式
// Determine target platform and scheduling mode
platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
if err != nil {
return nil, err
}
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
if account.Platform == platform {
valid = true
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
valid = true
}
if valid {
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
}
// 2. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
return account, nil
}
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
// Query schedulable accounts (force platform mode: try group first, fallback to all)
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -158,56 +111,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 混合调度模式下原生平台直接通过antigravity 需要启用 mixed_scheduling
// 非混合调度模式antigravity 分组):不需要过滤
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
// 4. 按优先级 + LRU 选择最佳账号
// Select best account by priority + LRU
selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
if selected == nil {
if requestedModel != "" {
@@ -216,6 +122,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return nil, errors.New("no available Gemini accounts")
}
// 5. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
}
@@ -223,6 +131,229 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return selected, nil
}
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
//
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, false, true, nil
}
if groupID != nil {
// 根据分组 platform 决定查询哪种账号
var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup
} else {
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return "", false, false, fmt.Errorf("get group failed: %w", err)
}
}
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
return group.Platform, group.Platform == PlatformGemini, false, nil
}
// 无分组时只使用原生 gemini 平台
return PlatformGemini, true, false, nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account unavailable.
func (s *GeminiMessagesCompatService) tryStickySessionHit(
ctx context.Context,
groupID *int64,
sessionHash, cacheKey, requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
if sessionHash == "" {
return nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err != nil || accountID <= 0 {
return nil
}
if _, excluded := excludedIDs[accountID]; excluded {
return nil
}
account, err := s.getSchedulableAccount(ctx, accountID)
if err != nil {
return nil
}
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account
}
// isAccountUsableForRequest 检查账号是否可用于当前请求。
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
//
// isAccountUsableForRequest checks if account is usable for current request.
// Validates: model scheduling, model support, platform matching, rate limit precheck.
func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
ctx context.Context,
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
) bool {
// 检查模型调度能力
// Check model scheduling capability
if !account.IsSchedulableForModel(requestedModel) {
return false
}
// 检查模型支持
// Check model support
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
return false
}
// 检查平台匹配
// Check platform matching
if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
return false
}
// 速率限制预检
// Rate limit precheck
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
return false
}
return true
}
// isAccountValidForPlatform 检查账号是否匹配目标平台。
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
//
// isAccountValidForPlatform checks if account matches target platform.
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
if account.Platform == platform {
return true
}
if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
return true
}
return false
}
// passesRateLimitPreCheck 执行速率限制预检。
// 返回 true 表示通过预检或无需预检。
//
// passesRateLimitPreCheck performs rate limit precheck.
// Returns true if passed or precheck not required.
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
if s.rateLimitService == nil || requestedModel == "" {
return true
}
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
return ok
}
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
// 返回 nil 表示无可用账号。
//
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
// Returns nil if no available account.
func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
ctx context.Context,
accounts []Account,
requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
var selected *Account
for i := range accounts {
acc := &accounts[i]
// 跳过被排除的账号
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 检查账号是否可用于当前请求
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
continue
}
// 选择最佳账号
if selected == nil {
selected = acc
continue
}
if s.isBetterGeminiAccount(acc, selected) {
selected = acc
}
}
return selected
}
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
// 规则优先级更高数值更小优先同优先级时未使用过的优先OAuth > 非 OAuth其次是最久未使用的。
//
// isBetterGeminiAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
// 优先级更高(数值更小)
if candidate.Priority < current.Priority {
return true
}
if candidate.Priority > current.Priority {
return false
}
// 同优先级,比较最后使用时间
switch {
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
// candidate 从未使用,优先
return true
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
// current 从未使用,保持
return false
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
default:
// 都使用过,选择最久未使用的
return candidate.LastUsedAt.Before(*current.LastUsedAt)
}
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {

View File

@@ -15,8 +15,10 @@ import (
// mockAccountRepoForGemini Gemini 测试用的 mock
type mockAccountRepoForGemini struct {
accounts []Account
accountsByID map[int64]*Account
accounts []Account
accountsByID map[int64]*Account
listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error)
}
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
@@ -107,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
if m.listByPlatformFunc != nil {
return m.listByPlatformFunc(ctx, platforms)
}
var result []Account
platformSet := make(map[string]bool)
for _, p := range platforms {
@@ -120,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
if m.listByGroupFunc != nil {
return m.listByGroupFunc(ctx, groupID, platforms)
}
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
@@ -215,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
@@ -236,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
return nil
}
func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if m.sessionBindings == nil {
return nil
}
if m.deletedSessions == nil {
m.deletedSessions = make(map[string]int)
}
m.deletedSessions[sessionHash]++
delete(m.sessionBindings, sessionHash)
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()
@@ -526,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
// 粘性会话未命中,按优先级选择
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
})
t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, 1, cache.deletedSessions["gemini:session-123"])
require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"])
})
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) {
ctx := context.Background()
groupID := int64(9)
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
repo := &mockAccountRepoForGemini{
listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return nil, nil
},
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
return []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
}, nil
},
accountsByID: map[int64]*Account{
1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{
ID: 1,
Platform: PlatformGemini,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}},
},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "supporting model")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-999": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
excluded := map[int64]struct{}{1: {}}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", excluded)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
return nil, errors.New("query failed")
},
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "query accounts failed")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) {
ctx := context.Background()
oldTime := time.Now().Add(-2 * time.Hour)
newTime := time.Now().Add(-1 * time.Hour)
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑

View File

@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
// GenerateAuthURL generates an OAuth authorization URL with full scope
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
return s.generateAuthURLWithScope(ctx, scope, proxyID)
return s.generateAuthURLWithScope(ctx, oauth.ScopeOAuth, proxyID)
}
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
@@ -176,7 +175,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
}
// Determine scope and if this is a setup token
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
// Internal API call uses ScopeAPI (org:create_api_key not supported)
scope := oauth.ScopeAPI
isSetupToken := false
if input.Scope == "inference" {
scope = oauth.ScopeInference

View File

@@ -180,67 +180,26 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. Check sticky session
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return account, nil
}
}
}
cacheKey := "openai:" + sessionHash
// 1. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil {
return account, nil
}
// 2. Get schedulable OpenAI accounts
// 2. 获取可调度的 OpenAI 账号
// Get schedulable OpenAI accounts
accounts, err := s.listSchedulableAccounts(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. Select by priority + LRU
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
// Lower priority value means higher priority
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
default:
// Same priority, select least recently used
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
if selected == nil {
if requestedModel != "" {
@@ -249,14 +208,138 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
return nil, errors.New("no available OpenAI accounts")
}
// 4. Set sticky session
// 4. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL)
}
return selected, nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account {
if sessionHash == "" {
return nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err != nil || accountID <= 0 {
return nil
}
if _, excluded := excludedIDs[accountID]; excluded {
return nil
}
account, err := s.getSchedulableAccount(ctx, accountID)
if err != nil {
return nil
}
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !account.IsSchedulable() || !account.IsOpenAI() {
return nil
}
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL)
return account
}
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU
// 返回 nil 表示无可用账号。
//
// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account
for i := range accounts {
acc := &accounts[i]
// 跳过被排除的账号
// Skip excluded accounts
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
if !acc.IsSchedulable() || !acc.IsOpenAI() {
continue
}
// 检查模型支持
// Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
if selected == nil {
selected = acc
continue
}
if s.isBetterAccount(acc, selected) {
selected = acc
}
}
return selected
}
// isBetterAccount 判断 candidate 是否比 current 更优。
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
//
// isBetterAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
// 优先级更高(数值更小)
// Higher priority (lower value)
if candidate.Priority < current.Priority {
return true
}
if candidate.Priority > current.Priority {
return false
}
// 同优先级,比较最后使用时间
// Same priority, compare last used time
switch {
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
// candidate 从未使用,优先
return true
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
// current 从未使用,保持
return false
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
// 都未使用,保持
return false
default:
// 都使用过,选择最久未使用的
return candidate.LastUsedAt.Before(*current.LastUsedAt)
}
}
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
@@ -325,29 +408,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
}
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}

View File

@@ -21,19 +21,50 @@ type stubOpenAIAccountRepo struct {
accounts []Account
}
func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
for i := range r.accounts {
if r.accounts[i].ID == id {
return &r.accounts[i], nil
}
}
return nil, errors.New("account not found")
}
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
type stubConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
loadMap map[int64]*AccountLoadInfo
acquireResults map[int64]bool
waitCounts map[int64]int
skipDefaultLoad bool
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok {
return result, nil
}
}
return true, nil
}
@@ -42,8 +73,25 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
}
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if c.loadBatchErr != nil {
return nil, c.loadBatchErr
}
out := make(map[int64]*AccountLoadInfo, len(accounts))
if c.skipDefaultLoad && c.loadMap != nil {
for _, acc := range accounts {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
}
}
return out, nil
}
for _, acc := range accounts {
if c.loadMap != nil {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
continue
}
}
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
@@ -92,6 +140,51 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
}
}
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {
return count, nil
}
}
return 0, nil
}
type stubGatewayCache struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := c.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if c.sessionBindings == nil {
c.sessionBindings = make(map[string]int64)
}
c.sessionBindings[sessionHash] = accountID
return nil
}
func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}
func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if c.sessionBindings == nil {
return nil
}
if c.deletedSessions == nil {
c.deletedSessions = make(map[string]int)
}
c.deletedSessions[sessionHash]++
delete(c.sessionBindings, sessionHash)
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
@@ -182,6 +275,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
}
}
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
sessionHash := "session-1"
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 2 {
t.Fatalf("expected account 2, got %+v", acc)
}
if cache.deletedSessions["openai:"+sessionHash] != 1 {
t.Fatalf("expected sticky session to be deleted")
}
if cache.sessionBindings["openai:"+sessionHash] != 2 {
t.Fatalf("expected sticky session to bind to account 2")
}
}
func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) {
sessionHash := "session-2"
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
t.Fatalf("expected account 2, got %+v", selection)
}
if cache.deletedSessions["openai:"+sessionHash] != 1 {
t.Fatalf("expected sticky session to be deleted")
}
if cache.sessionBindings["openai:"+sessionHash] != 2 {
t.Fatalf("expected sticky session to bind to account 2")
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
repo := stubOpenAIAccountRepo{
accounts: []Account{
{
ID: 1,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}},
},
},
}
cache := &stubGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
if err == nil {
t.Fatalf("expected error for unsupported model")
}
if acc != nil {
t.Fatalf("expected nil account for unsupported model")
}
if !strings.Contains(err.Error(), "supporting model") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadBatchErr: errors.New("load batch failed"),
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil {
t.Fatalf("expected selection")
}
if selection.Account.ID != 2 {
t.Fatalf("expected account 2, got %d", selection.Account.ID)
}
if cache.sessionBindings["openai:fallback"] != 2 {
t.Fatalf("expected sticky session updated")
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
acquireResults: map[int64]bool{1: false},
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 10},
},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.WaitPlan == nil {
t.Fatalf("expected wait plan fallback")
}
if selection.Account == nil || selection.Account.ID != 1 {
t.Fatalf("expected account 1")
}
}
func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) {
sessionHash := "bind"
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 1 {
t.Fatalf("expected account 1")
}
if cache.sessionBindings["openai:"+sessionHash] != 1 {
t.Fatalf("expected sticky session binding")
}
}
func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) {
sessionHash := "sticky-wait"
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
concurrencyCache := stubConcurrencyCache{
acquireResults: map[int64]bool{1: false},
waitCounts: map[int64]int{1: 0},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.WaitPlan == nil {
t.Fatalf("expected sticky wait plan")
}
if selection.Account == nil || selection.Account.ID != 1 {
t.Fatalf("expected account 1")
}
}
func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 80},
2: {AccountID: 2, LoadRate: 10},
},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
t.Fatalf("expected account 2")
}
if cache.sessionBindings["openai:load"] != 2 {
t.Fatalf("expected sticky session updated")
}
}
func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) {
sessionHash := "excluded"
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
excluded := map[int64]struct{}{1: {}}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) {
sessionHash := "non-openai"
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
repo := stubOpenAIAccountRepo{accounts: []Account{}}
cache := &stubGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil)
if err == nil {
t.Fatalf("expected error for no accounts")
}
if acc != nil {
t.Fatalf("expected nil account")
}
if !strings.Contains(err.Error(), "no available OpenAI accounts") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
groupID := int64(1)
resetAt := time.Now().Add(1 * time.Hour)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err == nil {
t.Fatalf("expected error for no candidates")
}
if selection != nil {
t.Fatalf("expected nil selection")
}
}
func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 100},
},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.WaitPlan == nil {
t.Fatalf("expected wait plan")
}
}
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadBatchErr: errors.New("load batch failed"),
acquireResults: map[int64]bool{1: false},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.WaitPlan == nil {
t.Fatalf("expected wait plan")
}
}
func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 50},
},
skipDefaultLoad: true,
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
oldTime := time.Now().Add(-2 * time.Hour)
newTime := time.Now().Add(-1 * time.Hour)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
},
}
cache := &stubGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
groupID := int64(1)
lastUsed := time.Now().Add(-1 * time.Hour)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 10},
2: {AccountID: 2, LoadRate: 10},
},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{

View File

@@ -38,8 +38,9 @@ type SessionLimitCache interface {
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// idleTimeouts: 每个账号的空闲超时时间配置key 为 accountID若为 nil 或某账号不在其中,则使用默认超时
// 返回 map[accountID]count查询失败的账号不在 map 中
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error)
// IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)

View File

@@ -60,6 +60,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys := []string{
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
SettingKeyPromoCodeEnabled,
SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey,
SettingKeySiteName,
@@ -69,6 +70,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyContactInfo,
SettingKeyDocURL,
SettingKeyHomeContent,
SettingKeyHideCcsImportButton,
SettingKeyLinuxDoConnectEnabled,
}
@@ -87,6 +89,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
@@ -96,6 +99,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil
}
@@ -123,6 +127,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
return &struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
SiteName string `json:"site_name"`
@@ -132,11 +137,13 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ContactInfo string `json:"contact_info,omitempty"`
DocURL string `json:"doc_url,omitempty"`
HomeContent string `json:"home_content,omitempty"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
@@ -146,6 +153,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
}, nil
@@ -158,6 +166,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 注册设置
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost
@@ -193,6 +202,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocURL] = settings.DocURL
updates[SettingKeyHomeContent] = settings.HomeContent
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@@ -243,6 +253,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
return value == "true"
}
// IsPromoCodeEnabled 检查是否启用优惠码功能
func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled)
if err != nil {
return true // 默认启用
}
return value != "false"
}
// GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
@@ -292,6 +311,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
defaults := map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
@@ -323,6 +343,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
@@ -339,6 +360,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
}
// 解析整数类型

View File

@@ -3,6 +3,7 @@ package service
type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
PromoCodeEnabled bool
SMTPHost string
SMTPPort int
@@ -25,13 +26,14 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string
SiteName string
SiteLogo string
SiteSubtitle string
APIBaseURL string
ContactInfo string
DocURL string
HomeContent string
SiteName string
SiteLogo string
SiteSubtitle string
APIBaseURL string
ContactInfo string
DocURL string
HomeContent string
HideCcsImportButton bool
DefaultConcurrency int
DefaultBalance float64
@@ -57,6 +59,7 @@ type SystemSettings struct {
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
PromoCodeEnabled bool
TurnstileEnabled bool
TurnstileSiteKey string
SiteName string
@@ -66,6 +69,7 @@ type PublicSettings struct {
ContactInfo string
DocURL string
HomeContent string
HideCcsImportButton bool
LinuxDoOAuthEnabled bool
Version string
}

View File

@@ -0,0 +1,54 @@
//go:build unit
// Package service 提供 API 网关核心服务。
// 本文件包含 shouldClearStickySession 函数的单元测试,
// 验证粘性会话清理逻辑在各种账号状态下的正确行为。
//
// This file contains unit tests for the shouldClearStickySession function,
// verifying correct sticky session clearing behavior under various account states.
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
// 验证在以下情况下是否正确判断需要清理粘性会话:
// - nil 账号:不清理(返回 false
// - 状态为错误或禁用:清理
// - 不可调度:清理
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
// nil account, error/disabled status, unschedulable, temporary unschedulable.
func TestShouldClearStickySession(t *testing.T) {
now := time.Now()
future := now.Add(1 * time.Hour)
past := now.Add(-1 * time.Hour)
tests := []struct {
name string
account *Account
want bool
}{
{name: "nil account", account: nil, want: false},
{name: "status error", account: &Account{Status: StatusError, Schedulable: true}, want: true},
{name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, want: true},
{name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, want: true},
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, want: true},
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, want: false},
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, shouldClearStickySession(tt.account))
})
}
}

View File

@@ -27,6 +27,7 @@ var (
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
)
// SubscriptionService 订阅服务
@@ -308,17 +309,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
return nil
}
// ExtendSubscription 延长订阅
// ExtendSubscription 调整订阅时长(正数延长,负数缩短)
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) {
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 限制延长天数
// 限制调整天数范围
if days > MaxValidityDays {
days = MaxValidityDays
}
if days < -MaxValidityDays {
days = -MaxValidityDays
}
// 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
@@ -326,6 +330,14 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
newExpiresAt = MaxExpiresAt
}
// 如果是缩短(负数),检查新的过期时间必须大于当前时间
if days < 0 {
now := time.Now()
if !newExpiresAt.After(now) {
return nil, ErrAdjustWouldExpire
}
}
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
return nil, err
}

View File

@@ -345,6 +345,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.deleteCalls, 3)
require.Equal(t, 2, repo.deleteCalls[0].limit)
require.True(t, repo.deleteCalls[0].filters.StartTime.Equal(start))
require.True(t, repo.deleteCalls[0].filters.EndTime.Equal(end))
require.Len(t, repo.markSucceeded, 1)
require.Empty(t, repo.markFailed)
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)

View File

@@ -408,6 +408,8 @@ sudo systemctl status redis
Sub2API supports TLS fingerprint simulation to make requests appear as if they come from the official Claude CLI (Node.js client).
> **💡 Tip:** Visit **[tls.sub2api.org](https://tls.sub2api.org/)** to get TLS fingerprint information for different devices and browsers.
### Default Behavior
- Built-in `claude_cli_v2` profile simulates Node.js 20.x + OpenSSL 3.x

View File

@@ -5,7 +5,7 @@
import { apiClient } from '../client'
import type {
Group,
AdminGroup,
GroupPlatform,
CreateGroupRequest,
UpdateGroupRequest,
@@ -31,8 +31,8 @@ export async function list(
options?: {
signal?: AbortSignal
}
): Promise<PaginatedResponse<Group>> {
const { data } = await apiClient.get<PaginatedResponse<Group>>('/admin/groups', {
): Promise<PaginatedResponse<AdminGroup>> {
const { data } = await apiClient.get<PaginatedResponse<AdminGroup>>('/admin/groups', {
params: {
page,
page_size: pageSize,
@@ -48,8 +48,8 @@ export async function list(
* @param platform - Optional platform filter
* @returns List of all active groups
*/
export async function getAll(platform?: GroupPlatform): Promise<Group[]> {
const { data } = await apiClient.get<Group[]>('/admin/groups/all', {
export async function getAll(platform?: GroupPlatform): Promise<AdminGroup[]> {
const { data } = await apiClient.get<AdminGroup[]>('/admin/groups/all', {
params: platform ? { platform } : undefined
})
return data
@@ -60,7 +60,7 @@ export async function getAll(platform?: GroupPlatform): Promise<Group[]> {
* @param platform - Platform to filter by
* @returns List of groups for the specified platform
*/
export async function getByPlatform(platform: GroupPlatform): Promise<Group[]> {
export async function getByPlatform(platform: GroupPlatform): Promise<AdminGroup[]> {
return getAll(platform)
}
@@ -69,8 +69,8 @@ export async function getByPlatform(platform: GroupPlatform): Promise<Group[]> {
* @param id - Group ID
* @returns Group details
*/
export async function getById(id: number): Promise<Group> {
const { data } = await apiClient.get<Group>(`/admin/groups/${id}`)
export async function getById(id: number): Promise<AdminGroup> {
const { data } = await apiClient.get<AdminGroup>(`/admin/groups/${id}`)
return data
}
@@ -79,8 +79,8 @@ export async function getById(id: number): Promise<Group> {
* @param groupData - Group data
* @returns Created group
*/
export async function create(groupData: CreateGroupRequest): Promise<Group> {
const { data } = await apiClient.post<Group>('/admin/groups', groupData)
export async function create(groupData: CreateGroupRequest): Promise<AdminGroup> {
const { data } = await apiClient.post<AdminGroup>('/admin/groups', groupData)
return data
}
@@ -90,8 +90,8 @@ export async function create(groupData: CreateGroupRequest): Promise<Group> {
* @param updates - Fields to update
* @returns Updated group
*/
export async function update(id: number, updates: UpdateGroupRequest): Promise<Group> {
const { data } = await apiClient.put<Group>(`/admin/groups/${id}`, updates)
export async function update(id: number, updates: UpdateGroupRequest): Promise<AdminGroup> {
const { data } = await apiClient.put<AdminGroup>(`/admin/groups/${id}`, updates)
return data
}
@@ -111,7 +111,7 @@ export async function deleteGroup(id: number): Promise<{ message: string }> {
* @param status - New status
* @returns Updated group
*/
export async function toggleStatus(id: number, status: 'active' | 'inactive'): Promise<Group> {
export async function toggleStatus(id: number, status: 'active' | 'inactive'): Promise<AdminGroup> {
return update(id, { status })
}

View File

@@ -12,6 +12,7 @@ export interface SystemSettings {
// Registration settings
registration_enabled: boolean
email_verify_enabled: boolean
promo_code_enabled: boolean
// Default settings
default_balance: number
default_concurrency: number
@@ -23,6 +24,7 @@ export interface SystemSettings {
contact_info: string
doc_url: string
home_content: string
hide_ccs_import_button: boolean
// SMTP settings
smtp_host: string
smtp_port: number
@@ -63,6 +65,7 @@ export interface SystemSettings {
export interface UpdateSettingsRequest {
registration_enabled?: boolean
email_verify_enabled?: boolean
promo_code_enabled?: boolean
default_balance?: number
default_concurrency?: number
site_name?: string
@@ -72,6 +75,7 @@ export interface UpdateSettingsRequest {
contact_info?: string
doc_url?: string
home_content?: string
hide_ccs_import_button?: boolean
smtp_host?: string
smtp_port?: number
smtp_username?: string

View File

@@ -4,7 +4,7 @@
*/
import { apiClient } from '../client'
import type { UsageLog, UsageQueryParams, PaginatedResponse } from '@/types'
import type { AdminUsageLog, UsageQueryParams, PaginatedResponse } from '@/types'
// ==================== Types ====================
@@ -85,8 +85,8 @@ export interface AdminUsageQueryParams extends UsageQueryParams {
export async function list(
params: AdminUsageQueryParams,
options?: { signal?: AbortSignal }
): Promise<PaginatedResponse<UsageLog>> {
const { data } = await apiClient.get<PaginatedResponse<UsageLog>>('/admin/usage', {
): Promise<PaginatedResponse<AdminUsageLog>> {
const { data } = await apiClient.get<PaginatedResponse<AdminUsageLog>>('/admin/usage', {
params,
signal: options?.signal
})

View File

@@ -4,7 +4,7 @@
*/
import { apiClient } from '../client'
import type { User, UpdateUserRequest, PaginatedResponse } from '@/types'
import type { AdminUser, UpdateUserRequest, PaginatedResponse } from '@/types'
/**
* List all users with pagination
@@ -26,7 +26,7 @@ export async function list(
options?: {
signal?: AbortSignal
}
): Promise<PaginatedResponse<User>> {
): Promise<PaginatedResponse<AdminUser>> {
// Build params with attribute filters in attr[id]=value format
const params: Record<string, any> = {
page,
@@ -44,8 +44,7 @@ export async function list(
}
}
}
const { data } = await apiClient.get<PaginatedResponse<User>>('/admin/users', {
const { data } = await apiClient.get<PaginatedResponse<AdminUser>>('/admin/users', {
params,
signal: options?.signal
})
@@ -57,8 +56,8 @@ export async function list(
* @param id - User ID
* @returns User details
*/
export async function getById(id: number): Promise<User> {
const { data } = await apiClient.get<User>(`/admin/users/${id}`)
export async function getById(id: number): Promise<AdminUser> {
const { data } = await apiClient.get<AdminUser>(`/admin/users/${id}`)
return data
}
@@ -73,8 +72,8 @@ export async function create(userData: {
balance?: number
concurrency?: number
allowed_groups?: number[] | null
}): Promise<User> {
const { data } = await apiClient.post<User>('/admin/users', userData)
}): Promise<AdminUser> {
const { data } = await apiClient.post<AdminUser>('/admin/users', userData)
return data
}
@@ -84,8 +83,8 @@ export async function create(userData: {
* @param updates - Fields to update
* @returns Updated user
*/
export async function update(id: number, updates: UpdateUserRequest): Promise<User> {
const { data } = await apiClient.put<User>(`/admin/users/${id}`, updates)
export async function update(id: number, updates: UpdateUserRequest): Promise<AdminUser> {
const { data } = await apiClient.put<AdminUser>(`/admin/users/${id}`, updates)
return data
}
@@ -112,8 +111,8 @@ export async function updateBalance(
balance: number,
operation: 'set' | 'add' | 'subtract' = 'set',
notes?: string
): Promise<User> {
const { data } = await apiClient.post<User>(`/admin/users/${id}/balance`, {
): Promise<AdminUser> {
const { data } = await apiClient.post<AdminUser>(`/admin/users/${id}/balance`, {
balance,
operation,
notes: notes || ''
@@ -127,7 +126,7 @@ export async function updateBalance(
* @param concurrency - New concurrency limit
* @returns Updated user
*/
export async function updateConcurrency(id: number, concurrency: number): Promise<User> {
export async function updateConcurrency(id: number, concurrency: number): Promise<AdminUser> {
return update(id, { concurrency })
}
@@ -137,7 +136,7 @@ export async function updateConcurrency(id: number, concurrency: number): Promis
* @param status - New status
* @returns Updated user
*/
export async function toggleStatus(id: number, status: 'active' | 'disabled'): Promise<User> {
export async function toggleStatus(id: number, status: 'active' | 'disabled'): Promise<AdminUser> {
return update(id, { status })
}

View File

@@ -648,7 +648,7 @@ import { ref, watch, computed } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
import type { Proxy, Group } from '@/types'
import type { Proxy, AdminGroup } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import Select from '@/components/common/Select.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
@@ -659,7 +659,7 @@ interface Props {
show: boolean
accountIds: number[]
proxies: Proxy[]
groups: Group[]
groups: AdminGroup[]
}
const props = defineProps<Props>()

View File

@@ -1816,7 +1816,7 @@ import {
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
import type { Proxy, Group, AccountPlatform, AccountType } from '@/types'
import type { Proxy, AdminGroup, AccountPlatform, AccountType } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import Icon from '@/components/icons/Icon.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
@@ -1862,7 +1862,7 @@ const apiKeyHint = computed(() => {
interface Props {
show: boolean
proxies: Proxy[]
groups: Group[]
groups: AdminGroup[]
}
const props = defineProps<Props>()

View File

@@ -883,7 +883,7 @@ import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { useAuthStore } from '@/stores/auth'
import { adminAPI } from '@/api/admin'
import type { Account, Proxy, Group } from '@/types'
import type { Account, Proxy, AdminGroup } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import Select from '@/components/common/Select.vue'
import Icon from '@/components/icons/Icon.vue'
@@ -901,7 +901,7 @@ interface Props {
show: boolean
account: Account | null
proxies: Proxy[]
groups: Group[]
groups: AdminGroup[]
}
const props = defineProps<Props>()

View File

@@ -4,6 +4,7 @@
<button @click="$emit('refresh')" :disabled="loading" class="btn btn-secondary">
<Icon name="refresh" size="md" :class="[loading ? 'animate-spin' : '']" />
</button>
<slot name="after"></slot>
<button @click="$emit('sync')" class="btn btn-secondary">{{ t('admin.accounts.syncFromCrs') }}</button>
<button @click="$emit('create')" class="btn btn-primary">{{ t('admin.accounts.createAccount') }}</button>
</div>

View File

@@ -239,7 +239,7 @@ import { formatDateTime } from '@/utils/format'
import DataTable from '@/components/common/DataTable.vue'
import EmptyState from '@/components/common/EmptyState.vue'
import Icon from '@/components/icons/Icon.vue'
import type { UsageLog } from '@/types'
import type { AdminUsageLog } from '@/types'
defineProps(['data', 'loading'])
const { t } = useI18n()
@@ -247,12 +247,12 @@ const { t } = useI18n()
// Tooltip state - cost
const tooltipVisible = ref(false)
const tooltipPosition = ref({ x: 0, y: 0 })
const tooltipData = ref<UsageLog | null>(null)
const tooltipData = ref<AdminUsageLog | null>(null)
// Tooltip state - token
const tokenTooltipVisible = ref(false)
const tokenTooltipPosition = ref({ x: 0, y: 0 })
const tokenTooltipData = ref<UsageLog | null>(null)
const tokenTooltipData = ref<AdminUsageLog | null>(null)
const cols = computed(() => [
{ key: 'user', label: t('admin.usage.user'), sortable: false },
@@ -296,7 +296,7 @@ const formatDuration = (ms: number | null | undefined): string => {
}
// Cost tooltip functions
const showTooltip = (event: MouseEvent, row: UsageLog) => {
const showTooltip = (event: MouseEvent, row: AdminUsageLog) => {
const target = event.currentTarget as HTMLElement
const rect = target.getBoundingClientRect()
tooltipData.value = row
@@ -311,7 +311,7 @@ const hideTooltip = () => {
}
// Token tooltip functions
const showTokenTooltip = (event: MouseEvent, row: UsageLog) => {
const showTokenTooltip = (event: MouseEvent, row: AdminUsageLog) => {
const target = event.currentTarget as HTMLElement
const rect = target.getBoundingClientRect()
tokenTooltipData.value = row

View File

@@ -39,10 +39,10 @@ import { ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
import type { User, Group } from '@/types'
import type { AdminUser, Group } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
const props = defineProps<{ show: boolean, user: User | null }>()
const props = defineProps<{ show: boolean, user: AdminUser | null }>()
const emit = defineEmits(['close', 'success']); const { t } = useI18n(); const appStore = useAppStore()
const groups = ref<Group[]>([]); const selectedIds = ref<number[]>([]); const loading = ref(false); const submitting = ref(false)
@@ -56,4 +56,4 @@ const handleSave = async () => {
appStore.showSuccess(t('admin.users.allowedGroupsUpdated')); emit('success'); emit('close')
} catch (error) { console.error('Failed to update allowed groups:', error) } finally { submitting.value = false }
}
</script>
</script>

View File

@@ -32,10 +32,10 @@ import { ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'
import { adminAPI } from '@/api/admin'
import { formatDateTime } from '@/utils/format'
import type { User, ApiKey } from '@/types'
import type { AdminUser, ApiKey } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
const props = defineProps<{ show: boolean, user: User | null }>()
const props = defineProps<{ show: boolean, user: AdminUser | null }>()
defineEmits(['close']); const { t } = useI18n()
const apiKeys = ref<ApiKey[]>([]); const loading = ref(false)
@@ -44,4 +44,4 @@ const load = async () => {
if (!props.user) return; loading.value = true
try { const res = await adminAPI.users.getUserApiKeys(props.user.id); apiKeys.value = res.items || [] } catch (error) { console.error('Failed to load API keys:', error) } finally { loading.value = false }
}
</script>
</script>

View File

@@ -29,10 +29,10 @@ import { reactive, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
import type { User } from '@/types'
import type { AdminUser } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
const props = defineProps<{ show: boolean, user: User | null, operation: 'add' | 'subtract' }>()
const props = defineProps<{ show: boolean, user: AdminUser | null, operation: 'add' | 'subtract' }>()
const emit = defineEmits(['close', 'success']); const { t } = useI18n(); const appStore = useAppStore()
const submitting = ref(false); const form = reactive({ amount: 0, notes: '' })

View File

@@ -56,12 +56,12 @@ import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { useClipboard } from '@/composables/useClipboard'
import { adminAPI } from '@/api/admin'
import type { User, UserAttributeValuesMap } from '@/types'
import type { AdminUser, UserAttributeValuesMap } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import UserAttributeForm from '@/components/user/UserAttributeForm.vue'
import Icon from '@/components/icons/Icon.vue'
const props = defineProps<{ show: boolean, user: User | null }>()
const props = defineProps<{ show: boolean, user: AdminUser | null }>()
const emit = defineEmits(['close', 'success'])
const { t } = useI18n(); const appStore = useAppStore(); const { copyToClipboard } = useClipboard()

View File

@@ -42,13 +42,13 @@
import { computed } from 'vue'
import { useI18n } from 'vue-i18n'
import GroupBadge from './GroupBadge.vue'
import type { Group, GroupPlatform } from '@/types'
import type { AdminGroup, GroupPlatform } from '@/types'
const { t } = useI18n()
interface Props {
modelValue: number[]
groups: Group[]
groups: AdminGroup[]
platform?: GroupPlatform // Optional platform filter
mixedScheduling?: boolean // For antigravity accounts: allow anthropic/gemini groups
}

View File

@@ -950,7 +950,7 @@ export default {
title: 'Subscription Management',
description: 'Manage user subscriptions and quota limits',
assignSubscription: 'Assign Subscription',
extendSubscription: 'Extend Subscription',
adjustSubscription: 'Adjust Subscription',
revokeSubscription: 'Revoke Subscription',
allStatus: 'All Status',
allGroups: 'All Groups',
@@ -965,6 +965,7 @@ export default {
resetInHoursMinutes: 'Resets in {hours}h {minutes}m',
resetInDaysHours: 'Resets in {days}d {hours}h',
daysRemaining: 'days remaining',
remainingDays: 'Remaining days',
noExpiration: 'No expiration',
status: {
active: 'Active',
@@ -983,28 +984,32 @@ export default {
user: 'User',
group: 'Subscription Group',
validityDays: 'Validity (Days)',
extendDays: 'Extend by (Days)'
adjustDays: 'Adjust by (Days)'
},
selectUser: 'Select a user',
selectGroup: 'Select a subscription group',
groupHint: 'Only groups with subscription billing type are shown',
validityHint: 'Number of days the subscription will be valid',
extendingFor: 'Extending subscription for',
adjustingFor: 'Adjusting subscription for',
currentExpiration: 'Current expiration',
adjustDaysPlaceholder: 'Positive to extend, negative to shorten',
adjustHint: 'Enter positive number to extend, negative to shorten (remaining days must be > 0)',
assign: 'Assign',
assigning: 'Assigning...',
extend: 'Extend',
extending: 'Extending...',
adjust: 'Adjust',
adjusting: 'Adjusting...',
revoke: 'Revoke',
noSubscriptionsYet: 'No subscriptions yet',
assignFirstSubscription: 'Assign a subscription to get started.',
subscriptionAssigned: 'Subscription assigned successfully',
subscriptionExtended: 'Subscription extended successfully',
subscriptionAdjusted: 'Subscription adjusted successfully',
subscriptionRevoked: 'Subscription revoked successfully',
failedToLoad: 'Failed to load subscriptions',
failedToAssign: 'Failed to assign subscription',
failedToExtend: 'Failed to extend subscription',
failedToAdjust: 'Failed to adjust subscription',
failedToRevoke: 'Failed to revoke subscription',
adjustWouldExpire: 'Remaining days after adjustment must be greater than 0',
adjustOutOfRange: 'Adjustment days must be between -36500 and 36500',
pleaseSelectUser: 'Please select a user',
pleaseSelectGroup: 'Please select a group',
validityDaysRequired: 'Please enter a valid number of days (at least 1)',
@@ -2721,7 +2726,9 @@ export default {
enableRegistration: 'Enable Registration',
enableRegistrationHint: 'Allow new users to register',
emailVerification: 'Email Verification',
emailVerificationHint: 'Require email verification for new registrations'
emailVerificationHint: 'Require email verification for new registrations',
promoCode: 'Promo Code',
promoCodeHint: 'Allow users to use promo codes during registration'
},
turnstile: {
title: 'Cloudflare Turnstile',
@@ -2791,7 +2798,9 @@ export default {
homeContent: 'Home Page Content',
homeContentPlaceholder: 'Enter custom content for the home page. Supports Markdown & HTML. If a URL is entered, it will be displayed as an iframe.',
homeContentHint: 'Customize the home page content. Supports Markdown/HTML. If you enter a URL (starting with http:// or https://), it will be used as an iframe src to embed an external page. When set, the default status information will no longer be displayed.',
homeContentIframeWarning: '⚠️ iframe mode note: Some websites have X-Frame-Options or CSP security policies that prevent embedding in iframes. If the page appears blank or shows an error, please verify the target website allows embedding, or consider using HTML mode to build your own content.'
homeContentIframeWarning: '⚠️ iframe mode note: Some websites have X-Frame-Options or CSP security policies that prevent embedding in iframes. If the page appears blank or shows an error, please verify the target website allows embedding, or consider using HTML mode to build your own content.',
hideCcsImportButton: 'Hide CCS Import Button',
hideCcsImportButtonHint: 'When enabled, the "Import to CCS" button will be hidden on the API Keys page'
},
smtp: {
title: 'SMTP Settings',

View File

@@ -1025,7 +1025,7 @@ export default {
title: '订阅管理',
description: '管理用户订阅和配额限制',
assignSubscription: '分配订阅',
extendSubscription: '延长订阅',
adjustSubscription: '调整订阅',
revokeSubscription: '撤销订阅',
allStatus: '全部状态',
allGroups: '全部分组',
@@ -1040,6 +1040,7 @@ export default {
resetInHoursMinutes: '{hours} 小时 {minutes} 分钟后重置',
resetInDaysHours: '{days} 天 {hours} 小时后重置',
daysRemaining: '天剩余',
remainingDays: '剩余天数',
noExpiration: '无过期时间',
status: {
active: '生效中',
@@ -1058,28 +1059,32 @@ export default {
user: '用户',
group: '订阅分组',
validityDays: '有效期(天)',
extendDays: '延长天数'
adjustDays: '调整天数'
},
selectUser: '选择用户',
selectGroup: '选择订阅分组',
groupHint: '仅显示订阅计费类型的分组',
validityHint: '订阅的有效天数',
extendingFor: '为以下用户延长订阅',
adjustingFor: '为以下用户调整订阅',
currentExpiration: '当前到期时间',
adjustDaysPlaceholder: '正数延长,负数缩短',
adjustHint: '输入正数延长订阅负数缩短订阅缩短后剩余天数需大于0',
assign: '分配',
assigning: '分配中...',
extend: '延长',
extending: '延长中...',
adjust: '调整',
adjusting: '调整中...',
revoke: '撤销',
noSubscriptionsYet: '暂无订阅',
assignFirstSubscription: '分配一个订阅以开始使用。',
subscriptionAssigned: '订阅分配成功',
subscriptionExtended: '订阅延长成功',
subscriptionAdjusted: '订阅调整成功',
subscriptionRevoked: '订阅撤销成功',
failedToLoad: '加载订阅列表失败',
failedToAssign: '分配订阅失败',
failedToExtend: '延长订阅失败',
failedToAdjust: '调整订阅失败',
failedToRevoke: '撤销订阅失败',
adjustWouldExpire: '调整后剩余天数必须大于0',
adjustOutOfRange: '调整天数必须在 -36500 到 36500 之间',
pleaseSelectUser: '请选择用户',
pleaseSelectGroup: '请选择分组',
validityDaysRequired: '请输入有效的天数至少1天',
@@ -2874,7 +2879,9 @@ export default {
enableRegistration: '开放注册',
enableRegistrationHint: '允许新用户注册',
emailVerification: '邮箱验证',
emailVerificationHint: '新用户注册时需要验证邮箱'
emailVerificationHint: '新用户注册时需要验证邮箱',
promoCode: '优惠码',
promoCodeHint: '允许用户在注册时使用优惠码'
},
turnstile: {
title: 'Cloudflare Turnstile',
@@ -2942,7 +2949,9 @@ export default {
homeContent: '首页内容',
homeContentPlaceholder: '在此输入首页内容,支持 Markdown & HTML 代码。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性。',
homeContentHint: '自定义首页内容,支持 Markdown/HTML。如果输入的是链接以 http:// 或 https:// 开头),则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为首页。设置后首页的状态信息将不再显示。',
homeContentIframeWarning: '⚠️ iframe 模式提示:部分网站设置了 X-Frame-Options 或 CSP 安全策略,禁止被嵌入到 iframe 中。如果页面显示空白或报错,请确认目标网站允许被嵌入,或考虑使用 HTML 模式自行构建页面内容。'
homeContentIframeWarning: '⚠️ iframe 模式提示:部分网站设置了 X-Frame-Options 或 CSP 安全策略,禁止被嵌入到 iframe 中。如果页面显示空白或报错,请确认目标网站允许被嵌入,或考虑使用 HTML 模式自行构建页面内容。',
hideCcsImportButton: '隐藏 CCS 导入按钮',
hideCcsImportButtonHint: '启用后将在 API Keys 页面隐藏"导入 CCS"按钮'
},
smtp: {
title: 'SMTP 设置',

View File

@@ -312,6 +312,7 @@ export const useAppStore = defineStore('app', () => {
return {
registration_enabled: false,
email_verify_enabled: false,
promo_code_enabled: true,
turnstile_enabled: false,
turnstile_site_key: '',
site_name: siteName.value,
@@ -321,6 +322,7 @@ export const useAppStore = defineStore('app', () => {
contact_info: contactInfo.value,
doc_url: docUrl.value,
home_content: '',
hide_ccs_import_button: false,
linuxdo_oauth_enabled: false,
version: siteVersion.value
}

View File

@@ -27,7 +27,6 @@ export interface FetchOptions {
export interface User {
id: number
username: string
notes: string
email: string
role: 'admin' | 'user' // User role for authorization
balance: number // User balance for API usage
@@ -39,6 +38,11 @@ export interface User {
updated_at: string
}
export interface AdminUser extends User {
// 管理员备注(普通用户接口不返回)
notes: string
}
export interface LoginRequest {
email: string
password: string
@@ -66,6 +70,7 @@ export interface SendVerifyCodeResponse {
export interface PublicSettings {
registration_enabled: boolean
email_verify_enabled: boolean
promo_code_enabled: boolean
turnstile_enabled: boolean
turnstile_site_key: string
site_name: string
@@ -75,6 +80,7 @@ export interface PublicSettings {
contact_info: string
doc_url: string
home_content: string
hide_ccs_import_button: boolean
linuxdo_oauth_enabled: boolean
version: string
}
@@ -269,14 +275,19 @@ export interface Group {
// Claude Code 客户端限制
claude_code_only: boolean
fallback_group_id: number | null
// 模型路由配置(仅 anthropic 平台使用)
model_routing: Record<string, number[]> | null
model_routing_enabled: boolean
account_count?: number
created_at: string
updated_at: string
}
export interface AdminGroup extends Group {
// 模型路由配置(仅管理员可见,内部信息)
model_routing: Record<string, number[]> | null
model_routing_enabled: boolean
// 分组下账号数量(仅管理员可见)
account_count?: number
}
export interface ApiKey {
id: number
user_id: number
@@ -636,7 +647,6 @@ export interface UsageLog {
total_cost: number
actual_cost: number
rate_multiplier: number
account_rate_multiplier?: number | null
billing_type: number
stream: boolean
@@ -650,18 +660,30 @@ export interface UsageLog {
// User-Agent
user_agent: string | null
// IP 地址(仅管理员可见)
ip_address: string | null
created_at: string
user?: User
api_key?: ApiKey
account?: Account
group?: Group
subscription?: UserSubscription
}
export interface UsageLogAccountSummary {
id: number
name: string
}
export interface AdminUsageLog extends UsageLog {
// 账号计费倍率(仅管理员可见)
account_rate_multiplier?: number | null
// 用户请求 IP仅管理员可见
ip_address?: string | null
// 最小账号信息(仅管理员接口返回)
account?: UsageLogAccountSummary
}
export interface UsageCleanupFilters {
start_time: string
end_time: string

View File

@@ -16,7 +16,7 @@
@sync="showSync = true"
@create="showCreate = true"
>
<template #before>
<template #after>
<!-- Column Settings Dropdown -->
<div class="relative" ref="columnDropdownRef">
<button
@@ -187,14 +187,14 @@ import AccountCapacityCell from '@/components/account/AccountCapacityCell.vue'
import PlatformTypeBadge from '@/components/common/PlatformTypeBadge.vue'
import Icon from '@/components/icons/Icon.vue'
import { formatDateTime, formatRelativeTime } from '@/utils/format'
import type { Account, Proxy, Group } from '@/types'
import type { Account, Proxy, AdminGroup } from '@/types'
const { t } = useI18n()
const appStore = useAppStore()
const authStore = useAuthStore()
const proxies = ref<Proxy[]>([])
const groups = ref<Group[]>([])
const groups = ref<AdminGroup[]>([])
const selIds = ref<number[]>([])
const showCreate = ref(false)
const showEdit = ref(false)

View File

@@ -1107,7 +1107,7 @@ import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { useOnboardingStore } from '@/stores/onboarding'
import { adminAPI } from '@/api/admin'
import type { Group, GroupPlatform, SubscriptionType } from '@/types'
import type { AdminGroup, GroupPlatform, SubscriptionType } from '@/types'
import type { Column } from '@/components/common/types'
import AppLayout from '@/components/layout/AppLayout.vue'
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
@@ -1202,7 +1202,7 @@ const fallbackGroupOptionsForEdit = computed(() => {
return options
})
const groups = ref<Group[]>([])
const groups = ref<AdminGroup[]>([])
const loading = ref(false)
const searchQuery = ref('')
const filters = reactive({
@@ -1223,8 +1223,8 @@ const showCreateModal = ref(false)
const showEditModal = ref(false)
const showDeleteDialog = ref(false)
const submitting = ref(false)
const editingGroup = ref<Group | null>(null)
const deletingGroup = ref<Group | null>(null)
const editingGroup = ref<AdminGroup | null>(null)
const deletingGroup = ref<AdminGroup | null>(null)
const createForm = reactive({
name: '',
@@ -1529,7 +1529,7 @@ const handleCreateGroup = async () => {
}
}
const handleEdit = async (group: Group) => {
const handleEdit = async (group: AdminGroup) => {
editingGroup.value = group
editForm.name = group.name
editForm.description = group.description || ''
@@ -1585,7 +1585,7 @@ const handleUpdateGroup = async () => {
}
}
const handleDelete = (group: Group) => {
const handleDelete = (group: AdminGroup) => {
deletingGroup.value = group
showDeleteDialog.value = true
}

View File

@@ -238,7 +238,30 @@
v-model="generateForm.group_id"
:options="subscriptionGroupOptions"
:placeholder="t('admin.redeem.selectGroupPlaceholder')"
/>
>
<template #selected="{ option }">
<GroupBadge
v-if="option"
:name="(option as unknown as GroupOption).label"
:platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate"
/>
<span v-else class="text-gray-400">{{
t('admin.redeem.selectGroupPlaceholder')
}}</span>
</template>
<template #option="{ option, selected }">
<GroupOptionItem
:name="(option as unknown as GroupOption).label"
:platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate"
:description="(option as unknown as GroupOption).description"
:selected="selected"
/>
</template>
</Select>
</div>
<div>
<label class="input-label">{{ t('admin.redeem.validityDays') }}</label>
@@ -370,7 +393,7 @@ import { useAppStore } from '@/stores/app'
import { useClipboard } from '@/composables/useClipboard'
import { adminAPI } from '@/api/admin'
import { formatDateTime } from '@/utils/format'
import type { RedeemCode, RedeemCodeType, Group } from '@/types'
import type { RedeemCode, RedeemCodeType, Group, GroupPlatform, SubscriptionType } from '@/types'
import type { Column } from '@/components/common/types'
import AppLayout from '@/components/layout/AppLayout.vue'
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
@@ -378,12 +401,23 @@ import DataTable from '@/components/common/DataTable.vue'
import Pagination from '@/components/common/Pagination.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import Icon from '@/components/icons/Icon.vue'
const { t } = useI18n()
const appStore = useAppStore()
const { copyToClipboard: clipboardCopy } = useClipboard()
interface GroupOption {
value: number
label: string
description: string | null
platform: GroupPlatform
subscriptionType: SubscriptionType
rate: number
}
const showGenerateDialog = ref(false)
const showResultDialog = ref(false)
const generatedCodes = ref<RedeemCode[]>([])
@@ -395,7 +429,11 @@ const subscriptionGroupOptions = computed(() => {
.filter((g) => g.subscription_type === 'subscription')
.map((g) => ({
value: g.id,
label: g.name
label: g.name,
description: g.description,
platform: g.platform,
subscriptionType: g.subscription_type,
rate: g.rate_multiplier
}))
})

View File

@@ -323,6 +323,21 @@
</div>
<Toggle v-model="form.email_verify_enabled" />
</div>
<!-- Promo Code -->
<div
class="flex items-center justify-between border-t border-gray-100 pt-4 dark:border-dark-700"
>
<div>
<label class="font-medium text-gray-900 dark:text-white">{{
t('admin.settings.registration.promoCode')
}}</label>
<p class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.settings.registration.promoCodeHint') }}
</p>
</div>
<Toggle v-model="form.promo_code_enabled" />
</div>
</div>
</div>
@@ -720,6 +735,21 @@
{{ t('admin.settings.site.homeContentIframeWarning') }}
</p>
</div>
<!-- Hide CCS Import Button -->
<div
class="flex items-center justify-between border-t border-gray-100 pt-4 dark:border-dark-700"
>
<div>
<label class="font-medium text-gray-900 dark:text-white">{{
t('admin.settings.site.hideCcsImportButton')
}}</label>
<p class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.settings.site.hideCcsImportButtonHint') }}
</p>
</div>
<Toggle v-model="form.hide_ccs_import_button" />
</div>
</div>
</div>
@@ -998,6 +1028,7 @@ type SettingsForm = SystemSettings & {
const form = reactive<SettingsForm>({
registration_enabled: true,
email_verify_enabled: false,
promo_code_enabled: true,
default_balance: 0,
default_concurrency: 1,
site_name: 'Sub2API',
@@ -1007,6 +1038,7 @@ const form = reactive<SettingsForm>({
contact_info: '',
doc_url: '',
home_content: '',
hide_ccs_import_button: false,
smtp_host: '',
smtp_port: 587,
smtp_username: '',
@@ -1119,6 +1151,7 @@ async function saveSettings() {
const payload: UpdateSettingsRequest = {
registration_enabled: form.registration_enabled,
email_verify_enabled: form.email_verify_enabled,
promo_code_enabled: form.promo_code_enabled,
default_balance: form.default_balance,
default_concurrency: form.default_concurrency,
site_name: form.site_name,
@@ -1128,6 +1161,7 @@ async function saveSettings() {
contact_info: form.contact_info,
doc_url: form.doc_url,
home_content: form.home_content,
hide_ccs_import_button: form.hide_ccs_import_button,
smtp_host: form.smtp_host,
smtp_port: form.smtp_port,
smtp_username: form.smtp_username,

View File

@@ -85,6 +85,14 @@
<!-- Right: Actions -->
<div class="ml-auto flex flex-wrap items-center justify-end gap-3">
<button
@click="loadSubscriptions"
:disabled="loading"
class="btn btn-secondary"
:title="t('common.refresh')"
>
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
<!-- Column Settings Dropdown -->
<div class="relative" ref="columnDropdownRef">
<button
@@ -136,14 +144,6 @@
</div>
</div>
</div>
<button
@click="loadSubscriptions"
:disabled="loading"
class="btn btn-secondary"
:title="t('common.refresh')"
>
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
<button @click="showAssignModal = true" class="btn btn-primary">
<Icon name="plus" size="md" class="mr-2" />
{{ t('admin.subscriptions.assignSubscription') }}
@@ -359,10 +359,10 @@
<button
v-if="row.status === 'active'"
@click="handleExtend(row)"
class="flex flex-col items-center gap-0.5 rounded-lg p-1.5 text-gray-500 transition-colors hover:bg-green-50 hover:text-green-600 dark:hover:bg-green-900/20 dark:hover:text-green-400"
class="flex flex-col items-center gap-0.5 rounded-lg p-1.5 text-gray-500 transition-colors hover:bg-blue-50 hover:text-blue-600 dark:hover:bg-blue-900/20 dark:hover:text-blue-400"
>
<Icon name="clock" size="sm" />
<span class="text-xs">{{ t('admin.subscriptions.extend') }}</span>
<Icon name="calendar" size="sm" />
<span class="text-xs">{{ t('admin.subscriptions.adjust') }}</span>
</button>
<button
v-if="row.status === 'active'"
@@ -466,7 +466,28 @@
v-model="assignForm.group_id"
:options="subscriptionGroupOptions"
:placeholder="t('admin.subscriptions.selectGroup')"
/>
>
<template #selected="{ option }">
<GroupBadge
v-if="option"
:name="(option as unknown as GroupOption).label"
:platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate"
/>
<span v-else class="text-gray-400">{{ t('admin.subscriptions.selectGroup') }}</span>
</template>
<template #option="{ option, selected }">
<GroupOptionItem
:name="(option as unknown as GroupOption).label"
:platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate"
:description="(option as unknown as GroupOption).description"
:selected="selected"
/>
</template>
</Select>
<p class="input-hint">{{ t('admin.subscriptions.groupHint') }}</p>
</div>
<div>
@@ -512,10 +533,10 @@
</template>
</BaseDialog>
<!-- Extend Subscription Modal -->
<!-- Adjust Subscription Modal -->
<BaseDialog
:show="showExtendModal"
:title="t('admin.subscriptions.extendSubscription')"
:title="t('admin.subscriptions.adjustSubscription')"
width="narrow"
@close="closeExtendModal"
>
@@ -527,7 +548,7 @@
>
<div class="rounded-lg bg-gray-50 p-4 dark:bg-dark-700">
<p class="text-sm text-gray-600 dark:text-gray-400">
{{ t('admin.subscriptions.extendingFor') }}
{{ t('admin.subscriptions.adjustingFor') }}
<span class="font-medium text-gray-900 dark:text-white">{{
extendingSubscription.user?.email
}}</span>
@@ -542,10 +563,25 @@
}}
</span>
</p>
<p v-if="extendingSubscription.expires_at" class="mt-1 text-sm text-gray-600 dark:text-gray-400">
{{ t('admin.subscriptions.remainingDays') }}:
<span class="font-medium text-gray-900 dark:text-white">
{{ getDaysRemaining(extendingSubscription.expires_at) ?? 0 }}
</span>
</p>
</div>
<div>
<label class="input-label">{{ t('admin.subscriptions.form.extendDays') }}</label>
<input v-model.number="extendForm.days" type="number" min="1" required class="input" />
<label class="input-label">{{ t('admin.subscriptions.form.adjustDays') }}</label>
<div class="flex items-center gap-2">
<input
v-model.number="extendForm.days"
type="number"
required
class="input text-center"
:placeholder="t('admin.subscriptions.adjustDaysPlaceholder')"
/>
</div>
<p class="input-hint">{{ t('admin.subscriptions.adjustHint') }}</p>
</div>
</form>
<template #footer>
@@ -559,7 +595,7 @@
:disabled="submitting"
class="btn btn-primary"
>
{{ submitting ? t('admin.subscriptions.extending') : t('admin.subscriptions.extend') }}
{{ submitting ? t('admin.subscriptions.adjusting') : t('admin.subscriptions.adjust') }}
</button>
</div>
</template>
@@ -584,7 +620,7 @@ import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
import type { UserSubscription, Group } from '@/types'
import type { UserSubscription, Group, GroupPlatform, SubscriptionType } from '@/types'
import type { SimpleUser } from '@/api/admin/usage'
import type { Column } from '@/components/common/types'
import { formatDateOnly } from '@/utils/format'
@@ -597,11 +633,21 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import EmptyState from '@/components/common/EmptyState.vue'
import Select from '@/components/common/Select.vue'
import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import Icon from '@/components/icons/Icon.vue'
const { t } = useI18n()
const appStore = useAppStore()
interface GroupOption {
value: number
label: string
description: string | null
platform: GroupPlatform
subscriptionType: SubscriptionType
rate: number
}
// User column display mode: 'email' or 'username'
const userColumnMode = ref<'email' | 'username'>('email')
const USER_COLUMN_MODE_KEY = 'subscription-user-column-mode'
@@ -777,7 +823,14 @@ const groupOptions = computed(() => [
const subscriptionGroupOptions = computed(() =>
groups.value
.filter((g) => g.subscription_type === 'subscription' && g.status === 'active')
.map((g) => ({ value: g.id, label: g.name }))
.map((g) => ({
value: g.id,
label: g.name,
description: g.description,
platform: g.platform,
subscriptionType: g.subscription_type,
rate: g.rate_multiplier
}))
)
const applyFilters = () => {
@@ -1000,17 +1053,27 @@ const closeExtendModal = () => {
const handleExtendSubscription = async () => {
if (!extendingSubscription.value) return
// 前端验证:调整后剩余天数必须 > 0
if (extendingSubscription.value.expires_at) {
const currentDaysRemaining = getDaysRemaining(extendingSubscription.value.expires_at) ?? 0
const newDaysRemaining = currentDaysRemaining + extendForm.days
if (newDaysRemaining <= 0) {
appStore.showError(t('admin.subscriptions.adjustWouldExpire'))
return
}
}
submitting.value = true
try {
await adminAPI.subscriptions.extend(extendingSubscription.value.id, {
days: extendForm.days
})
appStore.showSuccess(t('admin.subscriptions.subscriptionExtended'))
appStore.showSuccess(t('admin.subscriptions.subscriptionAdjusted'))
closeExtendModal()
loadSubscriptions()
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.subscriptions.failedToExtend'))
console.error('Error extending subscription:', error)
appStore.showError(error.response?.data?.detail || t('admin.subscriptions.failedToAdjust'))
console.error('Error adjusting subscription:', error)
} finally {
submitting.value = false
}

View File

@@ -42,11 +42,11 @@ import UsageStatsCards from '@/components/admin/usage/UsageStatsCards.vue'; impo
import UsageTable from '@/components/admin/usage/UsageTable.vue'; import UsageExportProgress from '@/components/admin/usage/UsageExportProgress.vue'
import UsageCleanupDialog from '@/components/admin/usage/UsageCleanupDialog.vue'
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'; import TokenUsageTrend from '@/components/charts/TokenUsageTrend.vue'
import type { UsageLog, TrendDataPoint, ModelStat } from '@/types'; import type { AdminUsageStatsResponse, AdminUsageQueryParams } from '@/api/admin/usage'
import type { AdminUsageLog, TrendDataPoint, ModelStat } from '@/types'; import type { AdminUsageStatsResponse, AdminUsageQueryParams } from '@/api/admin/usage'
const { t } = useI18n()
const appStore = useAppStore()
const usageStats = ref<AdminUsageStatsResponse | null>(null); const usageLogs = ref<UsageLog[]>([]); const loading = ref(false); const exporting = ref(false)
const usageStats = ref<AdminUsageStatsResponse | null>(null); const usageLogs = ref<AdminUsageLog[]>([]); const loading = ref(false); const exporting = ref(false)
const trendData = ref<TrendDataPoint[]>([]); const modelStats = ref<ModelStat[]>([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
let abortController: AbortController | null = null; let exportAbortController: AbortController | null = null
const exportProgress = reactive({ show: false, progress: 0, current: 0, total: 0, estimatedTime: '' })
@@ -92,7 +92,7 @@ const exportToExcel = async () => {
if (exporting.value) return; exporting.value = true; exportProgress.show = true
const c = new AbortController(); exportAbortController = c
try {
const all: UsageLog[] = []; let p = 1; let total = pagination.total
const all: AdminUsageLog[] = []; let p = 1; let total = pagination.total
while (true) {
const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value }, { signal: c.signal })
if (c.signal.aborted) break; if (p === 1) { total = res.total; exportProgress.total = total }

View File

@@ -492,7 +492,7 @@ import Icon from '@/components/icons/Icon.vue'
const { t } = useI18n()
import { adminAPI } from '@/api/admin'
import type { User, UserAttributeDefinition } from '@/types'
import type { AdminUser, UserAttributeDefinition } from '@/types'
import type { BatchUserUsageStats } from '@/api/admin/dashboard'
import type { Column } from '@/components/common/types'
import AppLayout from '@/components/layout/AppLayout.vue'
@@ -637,7 +637,7 @@ const columns = computed<Column[]>(() =>
)
)
const users = ref<User[]>([])
const users = ref<AdminUser[]>([])
const loading = ref(false)
const searchQuery = ref('')
@@ -736,16 +736,16 @@ const showEditModal = ref(false)
const showDeleteDialog = ref(false)
const showApiKeysModal = ref(false)
const showAttributesModal = ref(false)
const editingUser = ref<User | null>(null)
const deletingUser = ref<User | null>(null)
const viewingUser = ref<User | null>(null)
const editingUser = ref<AdminUser | null>(null)
const deletingUser = ref<AdminUser | null>(null)
const viewingUser = ref<AdminUser | null>(null)
let abortController: AbortController | null = null
// Action Menu State
const activeMenuId = ref<number | null>(null)
const menuPosition = ref<{ top: number; left: number } | null>(null)
const openActionMenu = (user: User, e: MouseEvent) => {
const openActionMenu = (user: AdminUser, e: MouseEvent) => {
if (activeMenuId.value === user.id) {
closeActionMenu()
} else {
@@ -821,11 +821,11 @@ const handleClickOutside = (event: MouseEvent) => {
// Allowed groups modal state
const showAllowedGroupsModal = ref(false)
const allowedGroupsUser = ref<User | null>(null)
const allowedGroupsUser = ref<AdminUser | null>(null)
// Balance (Deposit/Withdraw) modal state
const showBalanceModal = ref(false)
const balanceUser = ref<User | null>(null)
const balanceUser = ref<AdminUser | null>(null)
const balanceOperation = ref<'add' | 'subtract'>('add')
// 计算剩余天数
@@ -998,7 +998,7 @@ const applyFilter = () => {
loadUsers()
}
const handleEdit = (user: User) => {
const handleEdit = (user: AdminUser) => {
editingUser.value = user
showEditModal.value = true
}
@@ -1008,7 +1008,7 @@ const closeEditModal = () => {
editingUser.value = null
}
const handleToggleStatus = async (user: User) => {
const handleToggleStatus = async (user: AdminUser) => {
const newStatus = user.status === 'active' ? 'disabled' : 'active'
try {
await adminAPI.users.toggleStatus(user.id, newStatus)
@@ -1022,7 +1022,7 @@ const handleToggleStatus = async (user: User) => {
}
}
const handleViewApiKeys = (user: User) => {
const handleViewApiKeys = (user: AdminUser) => {
viewingUser.value = user
showApiKeysModal.value = true
}
@@ -1032,7 +1032,7 @@ const closeApiKeysModal = () => {
viewingUser.value = null
}
const handleAllowedGroups = (user: User) => {
const handleAllowedGroups = (user: AdminUser) => {
allowedGroupsUser.value = user
showAllowedGroupsModal.value = true
}
@@ -1042,7 +1042,7 @@ const closeAllowedGroupsModal = () => {
allowedGroupsUser.value = null
}
const handleDelete = (user: User) => {
const handleDelete = (user: AdminUser) => {
deletingUser.value = user
showDeleteDialog.value = true
}
@@ -1061,13 +1061,13 @@ const confirmDelete = async () => {
}
}
const handleDeposit = (user: User) => {
const handleDeposit = (user: AdminUser) => {
balanceUser.value = user
balanceOperation.value = 'add'
showBalanceModal.value = true
}
const handleWithdraw = (user: User) => {
const handleWithdraw = (user: AdminUser) => {
balanceUser.value = user
balanceOperation.value = 'subtract'
showBalanceModal.value = true

View File

@@ -96,7 +96,7 @@
</div>
<!-- Promo Code Input (Optional) -->
<div>
<div v-if="promoCodeEnabled">
<label for="promo_code" class="input-label">
{{ t('auth.promoCodeLabel') }}
<span class="ml-1 text-xs font-normal text-gray-400 dark:text-dark-500">({{ t('common.optional') }})</span>
@@ -260,6 +260,7 @@ const showPassword = ref<boolean>(false)
// Public settings
const registrationEnabled = ref<boolean>(true)
const emailVerifyEnabled = ref<boolean>(false)
const promoCodeEnabled = ref<boolean>(true)
const turnstileEnabled = ref<boolean>(false)
const turnstileSiteKey = ref<string>('')
const siteName = ref<string>('Sub2API')
@@ -294,22 +295,25 @@ const errors = reactive({
// ==================== Lifecycle ====================
onMounted(async () => {
// Read promo code from URL parameter
const promoParam = route.query.promo as string
if (promoParam) {
formData.promo_code = promoParam
// Validate the promo code from URL
await validatePromoCodeDebounced(promoParam)
}
try {
const settings = await getPublicSettings()
registrationEnabled.value = settings.registration_enabled
emailVerifyEnabled.value = settings.email_verify_enabled
promoCodeEnabled.value = settings.promo_code_enabled
turnstileEnabled.value = settings.turnstile_enabled
turnstileSiteKey.value = settings.turnstile_site_key || ''
siteName.value = settings.site_name || 'Sub2API'
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
// Read promo code from URL parameter only if promo code is enabled
if (promoCodeEnabled.value) {
const promoParam = route.query.promo as string
if (promoParam) {
formData.promo_code = promoParam
// Validate the promo code from URL
await validatePromoCodeDebounced(promoParam)
}
}
} catch (error) {
console.error('Failed to load public settings:', error)
} finally {

View File

@@ -133,6 +133,7 @@
</button>
<!-- Import to CC Switch Button -->
<button
v-if="!publicSettings?.hide_ccs_import_button"
@click="importToCcswitch(row)"
class="flex flex-col items-center gap-0.5 rounded-lg p-1.5 text-gray-500 transition-colors hover:bg-blue-50 hover:text-blue-600 dark:hover:bg-blue-900/20 dark:hover:text-blue-400"
>

View File

@@ -1,4 +1,4 @@
import { defineConfig, Plugin } from 'vite'
import { defineConfig, loadEnv, Plugin } from 'vite'
import vue from '@vitejs/plugin-vue'
import checker from 'vite-plugin-checker'
import { resolve } from 'path'
@@ -7,9 +7,7 @@ import { resolve } from 'path'
* Vite 插件:开发模式下注入公开配置到 index.html
* 与生产模式的后端注入行为保持一致,消除闪烁
*/
function injectPublicSettings(): Plugin {
const backendUrl = process.env.VITE_DEV_PROXY_TARGET || 'http://localhost:8080'
function injectPublicSettings(backendUrl: string): Plugin {
return {
name: 'inject-public-settings',
transformIndexHtml: {
@@ -35,15 +33,21 @@ function injectPublicSettings(): Plugin {
}
}
export default defineConfig({
plugins: [
vue(),
checker({
typescript: true,
vueTsc: true
}),
injectPublicSettings()
],
export default defineConfig(({ mode }) => {
// 加载环境变量
const env = loadEnv(mode, process.cwd(), '')
const backendUrl = env.VITE_DEV_PROXY_TARGET || 'http://localhost:8080'
const devPort = Number(env.VITE_DEV_PORT || 3000)
return {
plugins: [
vue(),
checker({
typescript: true,
vueTsc: true
}),
injectPublicSettings(backendUrl)
],
resolve: {
alias: {
'@': resolve(__dirname, 'src'),
@@ -102,17 +106,18 @@ export default defineConfig({
}
}
},
server: {
host: '0.0.0.0',
port: Number(process.env.VITE_DEV_PORT || 3000),
proxy: {
'/api': {
target: process.env.VITE_DEV_PROXY_TARGET || 'http://localhost:8080',
changeOrigin: true
},
'/setup': {
target: process.env.VITE_DEV_PROXY_TARGET || 'http://localhost:8080',
changeOrigin: true
server: {
host: '0.0.0.0',
port: devPort,
proxy: {
'/api': {
target: backendUrl,
changeOrigin: true
},
'/setup': {
target: backendUrl,
changeOrigin: true
}
}
}
}