Merge branch 'main' into test

This commit is contained in:
yangjianbo
2026-02-03 22:48:04 +08:00
235 changed files with 25155 additions and 7955 deletions

View File

@@ -47,6 +47,7 @@ type Config struct {
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@@ -479,6 +480,8 @@ type RedisConfig struct {
PoolSize int `mapstructure:"pool_size"`
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
MinIdleConns int `mapstructure:"min_idle_conns"`
// EnableTLS: 是否启用 TLS/SSL 连接
EnableTLS bool `mapstructure:"enable_tls"`
}
func (r *RedisConfig) Address() string {
@@ -531,6 +534,16 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
// TotpConfig TOTP 双因素认证配置
type TotpConfig struct {
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥32 字节 hex 编码)
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
EncryptionKey string `mapstructure:"encryption_key"`
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
EncryptionKeyConfigured bool `mapstructure:"-"`
}
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
@@ -691,6 +704,20 @@ func Load() (*Config, error) {
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
if cfg.Totp.EncryptionKey == "" {
key, err := generateJWTSecret(32) // Reuse the same random generation function
if err != nil {
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
}
cfg.Totp.EncryptionKey = key
cfg.Totp.EncryptionKeyConfigured = false
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
} else {
cfg.Totp.EncryptionKeyConfigured = true
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
@@ -802,6 +829,7 @@ func setDefaults() {
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
viper.SetDefault("redis.enable_tls", false)
// Ops (vNext)
viper.SetDefault("ops.enabled", true)
@@ -821,6 +849,9 @@ func setDefaults() {
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
// TOTP
viper.SetDefault("totp.encryption_key", "")
// Default
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.

View File

@@ -0,0 +1,226 @@
package domain
import (
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
AnnouncementStatusDraft = "draft"
AnnouncementStatusActive = "active"
AnnouncementStatusArchived = "archived"
)
const (
AnnouncementConditionTypeSubscription = "subscription"
AnnouncementConditionTypeBalance = "balance"
)
const (
AnnouncementOperatorIn = "in"
AnnouncementOperatorGT = "gt"
AnnouncementOperatorGTE = "gte"
AnnouncementOperatorLT = "lt"
AnnouncementOperatorLTE = "lte"
AnnouncementOperatorEQ = "eq"
)
var (
ErrAnnouncementNotFound = infraerrors.NotFound("ANNOUNCEMENT_NOT_FOUND", "announcement not found")
ErrAnnouncementInvalidTarget = infraerrors.BadRequest("ANNOUNCEMENT_INVALID_TARGET", "invalid announcement targeting rules")
)
type AnnouncementTargeting struct {
// AnyOf 表示 OR任意一个条件组满足即可展示。
AnyOf []AnnouncementConditionGroup `json:"any_of,omitempty"`
}
type AnnouncementConditionGroup struct {
// AllOf 表示 AND组内所有条件都满足才算命中该组。
AllOf []AnnouncementCondition `json:"all_of,omitempty"`
}
type AnnouncementCondition struct {
// Type: subscription | balance
Type string `json:"type"`
// Operator:
// - subscription: in
// - balance: gt/gte/lt/lte/eq
Operator string `json:"operator"`
// subscription 条件匹配的订阅套餐group_id
GroupIDs []int64 `json:"group_ids,omitempty"`
// balance 条件:比较阈值
Value float64 `json:"value,omitempty"`
}
func (t AnnouncementTargeting) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
// 空规则:展示给所有用户
if len(t.AnyOf) == 0 {
return true
}
for _, group := range t.AnyOf {
if len(group.AllOf) == 0 {
// 空条件组不命中(避免 OR 中出现无条件 “全命中”)
continue
}
allMatched := true
for _, cond := range group.AllOf {
if !cond.Matches(balance, activeSubscriptionGroupIDs) {
allMatched = false
break
}
}
if allMatched {
return true
}
}
return false
}
func (c AnnouncementCondition) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
switch c.Type {
case AnnouncementConditionTypeSubscription:
if c.Operator != AnnouncementOperatorIn {
return false
}
if len(c.GroupIDs) == 0 {
return false
}
if len(activeSubscriptionGroupIDs) == 0 {
return false
}
for _, gid := range c.GroupIDs {
if _, ok := activeSubscriptionGroupIDs[gid]; ok {
return true
}
}
return false
case AnnouncementConditionTypeBalance:
switch c.Operator {
case AnnouncementOperatorGT:
return balance > c.Value
case AnnouncementOperatorGTE:
return balance >= c.Value
case AnnouncementOperatorLT:
return balance < c.Value
case AnnouncementOperatorLTE:
return balance <= c.Value
case AnnouncementOperatorEQ:
return balance == c.Value
default:
return false
}
default:
return false
}
}
func (t AnnouncementTargeting) NormalizeAndValidate() (AnnouncementTargeting, error) {
normalized := AnnouncementTargeting{AnyOf: make([]AnnouncementConditionGroup, 0, len(t.AnyOf))}
// 允许空 targeting展示给所有用户
if len(t.AnyOf) == 0 {
return normalized, nil
}
if len(t.AnyOf) > 50 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
for _, g := range t.AnyOf {
if len(g.AllOf) == 0 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
if len(g.AllOf) > 50 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
group := AnnouncementConditionGroup{AllOf: make([]AnnouncementCondition, 0, len(g.AllOf))}
for _, c := range g.AllOf {
cond := AnnouncementCondition{
Type: strings.TrimSpace(c.Type),
Operator: strings.TrimSpace(c.Operator),
Value: c.Value,
}
for _, gid := range c.GroupIDs {
if gid <= 0 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
cond.GroupIDs = append(cond.GroupIDs, gid)
}
if err := cond.validate(); err != nil {
return AnnouncementTargeting{}, err
}
group.AllOf = append(group.AllOf, cond)
}
normalized.AnyOf = append(normalized.AnyOf, group)
}
return normalized, nil
}
func (c AnnouncementCondition) validate() error {
switch c.Type {
case AnnouncementConditionTypeSubscription:
if c.Operator != AnnouncementOperatorIn {
return ErrAnnouncementInvalidTarget
}
if len(c.GroupIDs) == 0 {
return ErrAnnouncementInvalidTarget
}
return nil
case AnnouncementConditionTypeBalance:
switch c.Operator {
case AnnouncementOperatorGT, AnnouncementOperatorGTE, AnnouncementOperatorLT, AnnouncementOperatorLTE, AnnouncementOperatorEQ:
return nil
default:
return ErrAnnouncementInvalidTarget
}
default:
return ErrAnnouncementInvalidTarget
}
}
type Announcement struct {
ID int64
Title string
Content string
Status string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
CreatedBy *int64
UpdatedBy *int64
CreatedAt time.Time
UpdatedAt time.Time
}
func (a *Announcement) IsActiveAt(now time.Time) bool {
if a == nil {
return false
}
if a.Status != AnnouncementStatusActive {
return false
}
if a.StartsAt != nil && now.Before(*a.StartsAt) {
return false
}
if a.EndsAt != nil && !now.Before(*a.EndsAt) {
// ends_at 语义:到点即下线
return false
}
return true
}

View File

@@ -0,0 +1,66 @@
package domain
// Status constants
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// Role constants
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformSora = "sora"
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
)
// Redeem type constants
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
RedeemTypeInvitation = "invitation"
)
// PromoCode status constants
const (
PromoCodeStatusActive = "active"
PromoCodeStatusDisabled = "disabled"
)
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
// Group subscription type constants
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
// Subscription status constants
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)

View File

@@ -547,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,更新凭证但不标记为 error
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
if tokenInfo.ProjectIDMissing {
// 先更新凭证
// 先更新凭证token 本身刷新成功了)
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
@@ -557,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
// 标记账户为 error
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id可能无法使用Antigravity"); setErr != nil {
response.InternalError(c, "Failed to set account error: "+setErr.Error())
return
}
// 标记为 error,只返回警告信息
response.Success(c, gin.H{
"message": "Token refreshed but project_id is missing, account marked as error",
"warning": "missing_project_id",
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
"warning": "missing_project_id_temporary",
})
return
}

View File

@@ -290,5 +290,9 @@ func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*ser
return &code, nil
}
func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]service.RedeemCode, int64, float64, error) {
return s.redeems, int64(len(s.redeems)), 100.0, nil
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -0,0 +1,246 @@
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AnnouncementHandler handles admin announcement management
type AnnouncementHandler struct {
announcementService *service.AnnouncementService
}
// NewAnnouncementHandler creates a new admin announcement handler
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
return &AnnouncementHandler{
announcementService: announcementService,
}
}
type CreateAnnouncementRequest struct {
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
}
type UpdateAnnouncementRequest struct {
Title *string `json:"title"`
Content *string `json:"content"`
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
Targeting *service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
}
// List handles listing announcements with filters
// GET /api/v1/admin/announcements
func (h *AnnouncementHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := strings.TrimSpace(c.Query("status"))
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
search = search[:200]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
items, paginationResult, err := h.announcementService.List(
c.Request.Context(),
params,
service.AnnouncementListFilters{Status: status, Search: search},
)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.Announcement, 0, len(items))
for i := range items {
out = append(out, *dto.AnnouncementFromService(&items[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
// GetByID handles getting an announcement by ID
// GET /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) GetByID(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
item, err := h.announcementService.GetByID(c.Request.Context(), announcementID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(item))
}
// Create handles creating a new announcement
// POST /api/v1/admin/announcements
func (h *AnnouncementHandler) Create(c *gin.Context) {
var req CreateAnnouncementRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
input := &service.CreateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil && *req.StartsAt > 0 {
t := time.Unix(*req.StartsAt, 0)
input.StartsAt = &t
}
if req.EndsAt != nil && *req.EndsAt > 0 {
t := time.Unix(*req.EndsAt, 0)
input.EndsAt = &t
}
created, err := h.announcementService.Create(c.Request.Context(), input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(created))
}
// Update handles updating an announcement
// PUT /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) Update(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
var req UpdateAnnouncementRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
input := &service.UpdateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil {
if *req.StartsAt == 0 {
var cleared *time.Time = nil
input.StartsAt = &cleared
} else {
t := time.Unix(*req.StartsAt, 0)
ptr := &t
input.StartsAt = &ptr
}
}
if req.EndsAt != nil {
if *req.EndsAt == 0 {
var cleared *time.Time = nil
input.EndsAt = &cleared
} else {
t := time.Unix(*req.EndsAt, 0)
ptr := &t
input.EndsAt = &ptr
}
}
updated, err := h.announcementService.Update(c.Request.Context(), announcementID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(updated))
}
// Delete handles deleting an announcement
// DELETE /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) Delete(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
if err := h.announcementService.Delete(c.Request.Context(), announcementID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Announcement deleted successfully"})
}
// ListReadStatus handles listing users read status for an announcement
// GET /api/v1/admin/announcements/:id/read-status
func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
search = search[:200]
}
items, paginationResult, err := h.announcementService.ListUserReadStatus(
c.Request.Context(),
announcementID,
params,
search,
)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, paginationResult.Total, page, pageSize)
}

View File

@@ -47,6 +47,8 @@ type CreateGroupRequest struct {
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
// UpdateGroupRequest represents update group request
@@ -74,6 +76,8 @@ type UpdateGroupRequest struct {
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
// List handles listing all groups with pagination
@@ -183,6 +187,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -229,6 +234,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)

View File

@@ -29,7 +29,7 @@ func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
// GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"min=0"`
GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用默认30天最大100年

View File

@@ -48,6 +48,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@@ -70,6 +74,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
@@ -89,9 +95,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -114,14 +123,16 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -198,6 +209,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// TOTP 双因素认证参数验证
// 只有手动配置了加密密钥才允许启用 TOTP 功能
if req.TotpEnabled && !previousSettings.TotpEnabled {
// 尝试启用 TOTP检查加密密钥是否已手动配置
if !h.settingService.IsTotpEncryptionKeyConfigured() {
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
return
}
}
// LinuxDo Connect 参数验证
if req.LinuxDoConnectEnabled {
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
@@ -227,6 +248,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// “购买订阅”页面配置验证
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
if req.PurchaseSubscriptionEnabled != nil {
purchaseEnabled = *req.PurchaseSubscriptionEnabled
}
purchaseURL := previousSettings.PurchaseSubscriptionURL
if req.PurchaseSubscriptionURL != nil {
purchaseURL = strings.TrimSpace(*req.PurchaseSubscriptionURL)
}
// - 启用时要求 URL 合法且非空
// - 禁用时允许为空;若提供了 URL 也做基本校验,避免误配置
if purchaseEnabled {
if purchaseURL == "" {
response.BadRequest(c, "Purchase Subscription URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
return
}
} else if purchaseURL != "" {
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
return
}
}
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@@ -240,40 +289,45 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -318,6 +372,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
@@ -340,6 +398,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -384,6 +444,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
if before.TotpEnabled != after.TotpEnabled {
changed = append(changed, "totp_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}

View File

@@ -77,7 +77,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
status := c.Query("status")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
// Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -277,3 +277,44 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
response.Success(c, stats)
}
// GetBalanceHistory handles getting user's balance/concurrency change history
// GET /api/v1/admin/users/:id/balance-history
// Query params:
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
page, pageSize := response.ParsePagination(c)
codeType := c.Query("type")
codes, total, totalRecharged, err := h.adminService.GetUserBalanceHistory(c.Request.Context(), userID, page, pageSize, codeType)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Convert to admin DTO (includes notes field for admin visibility)
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
// Custom response with total_recharged alongside pagination
pages := int((total + int64(pageSize) - 1) / int64(pageSize))
if pages < 1 {
pages = 1
}
response.Success(c, gin.H{
"items": out,
"total": total,
"page": page,
"page_size": pageSize,
"pages": pages,
"total_recharged": totalRecharged,
})
}

View File

@@ -0,0 +1,81 @@
package handler
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AnnouncementHandler handles user announcement operations
type AnnouncementHandler struct {
announcementService *service.AnnouncementService
}
// NewAnnouncementHandler creates a new user announcement handler
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
return &AnnouncementHandler{
announcementService: announcementService,
}
}
// List handles listing announcements visible to current user
// GET /api/v1/announcements
func (h *AnnouncementHandler) List(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
unreadOnly := parseBoolQuery(c.Query("unread_only"))
items, err := h.announcementService.ListForUser(c.Request.Context(), subject.UserID, unreadOnly)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.UserAnnouncement, 0, len(items))
for i := range items {
out = append(out, *dto.UserAnnouncementFromService(&items[i]))
}
response.Success(c, out)
}
// MarkRead marks an announcement as read for current user
// POST /api/v1/announcements/:id/read
func (h *AnnouncementHandler) MarkRead(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
if err := h.announcementService.MarkRead(c.Request.Context(), subject.UserID, announcementID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "ok"})
}
func parseBoolQuery(v string) bool {
switch strings.TrimSpace(strings.ToLower(v)) {
case "1", "true", "yes", "y", "on":
return true
default:
return false
}
}

View File

@@ -1,6 +1,8 @@
package handler
import (
"log/slog"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
@@ -13,21 +15,25 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
redeemService *service.RedeemService
totpService *service.TotpService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
redeemService: redeemService,
totpService: totpService,
}
}
@@ -37,7 +43,8 @@ type RegisterRequest struct {
Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
}
// SendVerifyCodeRequest 发送验证码请求
@@ -83,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -144,6 +151,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
if err != nil {
response.InternalError(c, "Failed to create 2FA session")
return
}
response.Success(c, TotpLoginResponse{
Requires2FA: true,
TempToken: tempToken,
UserEmailMasked: service.MaskEmail(user.Email),
})
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
}
// TotpLoginResponse represents the response when 2FA is required
type TotpLoginResponse struct {
Requires2FA bool `json:"requires_2fa"`
TempToken string `json:"temp_token,omitempty"`
UserEmailMasked string `json:"user_email_masked,omitempty"`
}
// Login2FARequest represents the 2FA login request
type Login2FARequest struct {
TempToken string `json:"temp_token" binding:"required"`
TotpCode string `json:"totp_code" binding:"required,len=6"`
}
// Login2FA completes the login with 2FA verification
// POST /api/v1/auth/login/2fa
func (h *AuthHandler) Login2FA(c *gin.Context) {
var req Login2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
slog.Debug("login_2fa_request",
"temp_token_len", len(req.TempToken),
"totp_code_len", len(req.TotpCode))
// Get the login session
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
if err != nil || session == nil {
tokenPrefix := ""
if len(req.TempToken) >= 8 {
tokenPrefix = req.TempToken[:8]
}
slog.Debug("login_2fa_session_invalid",
"temp_token_prefix", tokenPrefix,
"error", err)
response.BadRequest(c, "Invalid or expired 2FA session")
return
}
slog.Debug("login_2fa_session_found",
"user_id", session.UserID,
"email", session.Email)
// Verify the TOTP code
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
slog.Debug("login_2fa_verify_failed",
"user_id", session.UserID,
"error", err)
response.ErrorFrom(c, err)
return
}
// Delete the login session
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Generate the JWT token
token, err := h.authService.GenerateToken(user)
if err != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
@@ -247,3 +348,146 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
BonusAmount: promoCode.BonusAmount,
})
}
// ValidateInvitationCodeRequest 验证邀请码请求
type ValidateInvitationCodeRequest struct {
Code string `json:"code" binding:"required"`
}
// ValidateInvitationCodeResponse 验证邀请码响应
type ValidateInvitationCodeResponse struct {
Valid bool `json:"valid"`
ErrorCode string `json:"error_code,omitempty"`
}
// ValidateInvitationCode 验证邀请码(公开接口,注册前调用)
// POST /api/v1/auth/validate-invitation-code
func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) {
// 检查邀请码功能是否启用
if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_DISABLED",
})
return
}
var req ValidateInvitationCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 验证邀请码
redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code)
if err != nil {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_NOT_FOUND",
})
return
}
// 检查类型和状态
if redeemCode.Type != service.RedeemTypeInvitation {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_INVALID",
})
return
}
if redeemCode.Status != service.StatusUnused {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_USED",
})
return
}
response.Success(c, ValidateInvitationCodeResponse{
Valid: true,
})
}
// ForgotPasswordRequest 忘记密码请求
type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token"`
}
// ForgotPasswordResponse 忘记密码响应
type ForgotPasswordResponse struct {
Message string `json:"message"`
}
// ForgotPassword 请求密码重置
// POST /api/v1/auth/forgot-password
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
var req ForgotPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ForgotPasswordResponse{
Message: "If your email is registered, you will receive a password reset link shortly.",
})
}
// ResetPasswordRequest 重置密码请求
type ResetPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
Token string `json:"token" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// ResetPasswordResponse 重置密码响应
type ResetPasswordResponse struct {
Message string `json:"message"`
}
// ResetPassword 重置密码
// POST /api/v1/auth/reset-password
func (h *AuthHandler) ResetPassword(c *gin.Context) {
var req ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Reset password
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ResetPasswordResponse{
Message: "Your password has been reset successfully. You can now log in with your new password.",
})
}

View File

@@ -0,0 +1,74 @@
package dto
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type Announcement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Status string `json:"status"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *time.Time `json:"starts_at,omitempty"`
EndsAt *time.Time `json:"ends_at,omitempty"`
CreatedBy *int64 `json:"created_by,omitempty"`
UpdatedBy *int64 `json:"updated_by,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type UserAnnouncement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
StartsAt *time.Time `json:"starts_at,omitempty"`
EndsAt *time.Time `json:"ends_at,omitempty"`
ReadAt *time.Time `json:"read_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func AnnouncementFromService(a *service.Announcement) *Announcement {
if a == nil {
return nil
}
return &Announcement{
ID: a.ID,
Title: a.Title,
Content: a.Content,
Status: a.Status,
Targeting: a.Targeting,
StartsAt: a.StartsAt,
EndsAt: a.EndsAt,
CreatedBy: a.CreatedBy,
UpdatedBy: a.UpdatedBy,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
}
}
func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement {
if a == nil {
return nil
}
return &UserAnnouncement{
ID: a.Announcement.ID,
Title: a.Announcement.Title,
Content: a.Announcement.Content,
StartsAt: a.Announcement.StartsAt,
EndsAt: a.Announcement.EndsAt,
ReadAt: a.ReadAt,
CreatedAt: a.Announcement.CreatedAt,
UpdatedAt: a.Announcement.UpdatedAt,
}
}

View File

@@ -208,6 +208,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
}
}
if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 {
out.ScopeRateLimits = make(map[string]ScopeRateLimitInfo, len(scopeLimits))
now := time.Now()
for scope, remainingSec := range scopeLimits {
out.ScopeRateLimits[scope] = ScopeRateLimitInfo{
ResetAt: now.Add(time.Duration(remainingSec) * time.Second),
RemainingSec: remainingSec,
}
}
}
return out
}
@@ -325,7 +336,7 @@ func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
}
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
return RedeemCode{
out := RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
@@ -339,6 +350,14 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
// For admin_balance/admin_concurrency types, include notes so users can see
// why they were charged or credited by admin
if (rc.Type == "admin_balance" || rc.Type == "admin_concurrency") && rc.Notes != "" {
out.Notes = &rc.Notes
}
return out
}
// AccountSummaryFromService returns a minimal AccountSummary for usage log display.
@@ -362,6 +381,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
ReasoningEffort: l.ReasoningEffort,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,

View File

@@ -2,9 +2,13 @@ package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -23,14 +27,16 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@@ -54,21 +60,26 @@ type SystemSettings struct {
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO

View File

@@ -2,6 +2,11 @@ package dto
import "time"
type ScopeRateLimitInfo struct {
ResetAt time.Time `json:"reset_at"`
RemainingSec int64 `json:"remaining_sec"`
}
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
@@ -114,6 +119,9 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
// Antigravity scope 级限流状态(从 extra 提取)
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
@@ -204,6 +212,10 @@ type RedeemCode struct {
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
// Notes is only populated for admin_balance/admin_concurrency types
// so users can see why they were charged or credited
Notes *string `json:"notes,omitempty"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
@@ -224,6 +236,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable.
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`

View File

@@ -30,6 +30,7 @@ type GatewayHandler struct {
antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService
billingCacheService *service.BillingCacheService
usageService *service.UsageService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
@@ -44,6 +45,7 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
@@ -64,6 +66,7 @@ func NewGatewayHandler(
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
usageService: usageService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
@@ -537,7 +540,7 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
})
}
// Usage handles getting account balance for CC Switch integration
// Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
@@ -552,7 +555,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return
}
// 订阅模式:返回订阅限额信息
// Best-effort: 获取用量统计,失败不影响基础响应
var usageData gin.H
if h.usageService != nil {
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err == nil && dashStats != nil {
usageData = gin.H{
"today": gin.H{
"requests": dashStats.TodayRequests,
"input_tokens": dashStats.TodayInputTokens,
"output_tokens": dashStats.TodayOutputTokens,
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
"cache_read_tokens": dashStats.TodayCacheReadTokens,
"total_tokens": dashStats.TodayTokens,
"cost": dashStats.TodayCost,
"actual_cost": dashStats.TodayActualCost,
},
"total": gin.H{
"requests": dashStats.TotalRequests,
"input_tokens": dashStats.TotalInputTokens,
"output_tokens": dashStats.TotalOutputTokens,
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
"cache_read_tokens": dashStats.TotalCacheReadTokens,
"total_tokens": dashStats.TotalTokens,
"cost": dashStats.TotalCost,
"actual_cost": dashStats.TotalActualCost,
},
"average_duration_ms": dashStats.AverageDurationMs,
"rpm": dashStats.Rpm,
"tpm": dashStats.Tpm,
}
}
}
// 订阅模式:返回订阅限额信息 + 用量统计
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok {
@@ -561,28 +597,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
}
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
c.JSON(http.StatusOK, gin.H{
resp := gin.H{
"isValid": true,
"planName": apiKey.Group.Name,
"remaining": remaining,
"unit": "USD",
})
"subscription": gin.H{
"daily_usage_usd": subscription.DailyUsageUSD,
"weekly_usage_usd": subscription.WeeklyUsageUSD,
"monthly_usage_usd": subscription.MonthlyUsageUSD,
"daily_limit_usd": apiKey.Group.DailyLimitUSD,
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
"expires_at": subscription.ExpiresAt,
},
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
return
}
// 余额模式:返回钱包余额
// 余额模式:返回钱包余额 + 用量统计
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
}
c.JSON(http.StatusOK, gin.H{
resp := gin.H{
"isValid": true,
"planName": "钱包余额",
"remaining": latestUser.Balance,
"unit": "USD",
})
"balance": latestUser.Balance,
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
}
// calculateSubscriptionRemaining 计算订阅剩余可用额度
@@ -738,6 +792,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)

View File

@@ -0,0 +1,122 @@
//go:build unit
package handler
import (
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractGeminiCLISessionHash(t *testing.T) {
tests := []struct {
name string
body string
privilegedUserID string
wantEmpty bool
wantHash string
}{
{
name: "with privileged-user-id and tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: false,
wantHash: func() string {
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}(),
},
{
name: "without privileged-user-id but with tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "",
wantEmpty: false,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "without tmp dir",
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
{
name: "empty body",
body: "",
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 创建测试上下文
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/test", nil)
if tt.privilegedUserID != "" {
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
}
// 调用函数
result := extractGeminiCLISessionHash(c, []byte(tt.body))
// 验证结果
if tt.wantEmpty {
require.Empty(t, result, "expected empty session hash")
} else {
require.NotEmpty(t, result, "expected non-empty session hash")
require.Equal(t, tt.wantHash, result, "session hash mismatch")
}
})
}
}
func TestGeminiCLITmpDirRegex(t *testing.T) {
tests := []struct {
name string
input string
wantMatch bool
wantHash string
}{
{
name: "valid tmp dir path",
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "valid tmp dir path in text",
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "invalid hash length",
input: "/Users/ianshaw/.gemini/tmp/abc123",
wantMatch: false,
},
{
name: "no tmp dir",
input: "Hello world",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
if tt.wantMatch {
require.NotNil(t, match, "expected regex to match")
require.Len(t, match, 2, "expected 2 capture groups")
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
} else {
require.Nil(t, match, "expected regex not to match")
}
})
}
}

View File

@@ -1,11 +1,15 @@
package handler
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
"time"
@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 优先使用 Gemini CLI 的会话标识privileged-user-id + tmp 目录哈希)
sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
// 查询粘性会话绑定的账号 ID用于检测账号切换
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后CLI 继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
} else if sessionBoundAccountID == 0 {
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID = account.ID
}
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
@@ -319,18 +366,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 6) record usage async
// 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
@@ -433,3 +483,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
}
return false
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希64位十六进制
// 2. 从 header 中提取 privileged-user-idUUID
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 1. 从请求体中提取 tmp 目录哈希
match := geminiCLITmpDirRegex.FindSubmatch(body)
if len(match) < 2 {
return "" // 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash := string(match[1])
// 2. 提取 privileged-user-id
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
// 3. 组合生成最终的 session hash
if privilegedUserID != "" {
// 组合两个标识符privileged-user-id + tmp 目录哈希
combined := privilegedUserID + ":" + tmpDirHash
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}
// 如果没有 privileged-user-id直接使用 tmp 目录哈希
return tmpDirHash
}

View File

@@ -10,6 +10,7 @@ type AdminHandlers struct {
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
@@ -33,11 +34,13 @@ type Handlers struct {
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Announcement *AnnouncementHandler
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
Setting *SettingHandler
Totp *TotpHandler
}
// BuildInfo contains build-time information

View File

@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
return true
}
if phase == "billing" || phase == "concurrency" {
@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
}
}
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
if settings.IgnoreInvalidApiKeyErrors {
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
return true
}
}
return false
}

View File

@@ -32,20 +32,25 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
})
}

View File

@@ -0,0 +1,181 @@
package handler
import (
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TotpHandler handles TOTP-related requests
type TotpHandler struct {
totpService *service.TotpService
}
// NewTotpHandler creates a new TotpHandler
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
return &TotpHandler{
totpService: totpService,
}
}
// TotpStatusResponse represents the TOTP status response
type TotpStatusResponse struct {
Enabled bool `json:"enabled"`
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
FeatureEnabled bool `json:"feature_enabled"`
}
// GetStatus returns the TOTP status for the current user
// GET /api/v1/user/totp/status
func (h *TotpHandler) GetStatus(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := TotpStatusResponse{
Enabled: status.Enabled,
FeatureEnabled: status.FeatureEnabled,
}
if status.EnabledAt != nil {
ts := status.EnabledAt.Unix()
resp.EnabledAt = &ts
}
response.Success(c, resp)
}
// TotpSetupRequest represents the request to initiate TOTP setup
type TotpSetupRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// TotpSetupResponse represents the TOTP setup response
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"`
}
// InitiateSetup starts the TOTP setup process
// POST /api/v1/user/totp/setup
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpSetupRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body (optional params)
req = TotpSetupRequest{}
}
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, TotpSetupResponse{
Secret: result.Secret,
QRCodeURL: result.QRCodeURL,
SetupToken: result.SetupToken,
Countdown: result.Countdown,
})
}
// TotpEnableRequest represents the request to enable TOTP
type TotpEnableRequest struct {
TotpCode string `json:"totp_code" binding:"required,len=6"`
SetupToken string `json:"setup_token" binding:"required"`
}
// Enable completes the TOTP setup
// POST /api/v1/user/totp/enable
func (h *TotpHandler) Enable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpEnableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// TotpDisableRequest represents the request to disable TOTP
type TotpDisableRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// Disable disables TOTP for the current user
// POST /api/v1/user/totp/disable
func (h *TotpHandler) Disable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpDisableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// GetVerificationMethod returns the verification method for TOTP operations
// GET /api/v1/user/totp/verification-method
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
method := h.totpService.GetVerificationMethod(c.Request.Context())
response.Success(c, method)
}
// SendVerifyCode sends an email verification code for TOTP operations
// POST /api/v1/user/totp/send-code
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}

View File

@@ -13,6 +13,7 @@ func ProvideAdminHandlers(
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
@@ -32,6 +33,7 @@ func ProvideAdminHandlers(
User: userHandler,
Group: groupHandler,
Account: accountHandler,
Announcement: announcementHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
@@ -66,11 +68,13 @@ func ProvideHandlers(
usageHandler *UsageHandler,
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
announcementHandler *AnnouncementHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
@@ -79,11 +83,13 @@ func ProvideHandlers(
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Announcement: announcementHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
Setting: settingHandler,
Totp: totpHandler,
}
}
@@ -96,9 +102,11 @@ var ProviderSet = wire.NewSet(
NewUsageHandler,
NewRedeemHandler,
NewSubscriptionHandler,
NewAnnouncementHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewSoraGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
// Admin handlers
@@ -106,6 +114,7 @@ var ProviderSet = wire.NewSet(
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewAnnouncementHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,

View File

@@ -33,7 +33,7 @@ const (
"https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent与 Antigravity-Manager 保持一致)
UserAgent = "antigravity/1.11.9 windows/amd64"
UserAgent = "antigravity/1.15.8 windows/amd64"
// Session 过期时间
SessionTTL = 30 * time.Minute

View File

@@ -367,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking,
Thought: true,
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
// signature 处理:
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失)
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature在缺失时降级为普通文本并在上层禁用 thinking mode。
@@ -407,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
},
}
// tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature跳过 thought_signature 校验
// - Claude 模型:透传上游返回的真实 signatureVertex/Google 需要完整签名链路)
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)

View File

@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
contentNoSig := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
}

View File

@@ -9,11 +9,26 @@ const (
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent": "claude-cli/2.1.22 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0",
"X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Arch": "arm64",
"X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0",
"X-Stainless-Runtime-Version": "v24.13.0",
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60",
"X-Stainless-Timeout": "600",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var ModelIDOverrides = map[string]string{
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
"claude-opus-4-5": "claude-opus-4-5-20251101",
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var ModelIDReverseOverrides = map[string]string{
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-opus-4-5-20251101": "claude-opus-4-5",
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func NormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDOverrides[id]; ok {
return mapped
}
return id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func DenormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDReverseOverrides[id]; ok {
return mapped
}
return id
}

View File

@@ -2,6 +2,7 @@
package response
import (
"log"
"math"
"net/http"
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
}
statusCode, status := infraerrors.ToHTTP(err)
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}

View File

@@ -0,0 +1,278 @@
//go:build integration
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests to external services and should be run manually.
//
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
)
// skipIfExternalServiceUnavailable checks if the external service is available.
// If not, it skips the test instead of failing.
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
t.Helper()
if err != nil {
// Check for common network/TLS errors that indicate external service issues
errStr := err.Error()
if strings.Contains(errStr, "certificate has expired") ||
strings.Contains(errStr, "certificate is not yet valid") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "no such host") ||
strings.Contains(errStr, "network is unreachable") ||
strings.Contains(errStr, "timeout") {
t.Skipf("skipping test: external service unavailable: %v", err)
}
t.Fatalf("failed to get fingerprint: %v", err)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}

View File

@@ -1,10 +1,11 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests and should be run manually.
// Unit tests for TLS fingerprint dialer.
// Integration tests that require external network are in dialer_integration_test.go
// and require the 'integration' build tag.
//
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (

View File

@@ -809,12 +809,21 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
return err
}
path := "{antigravity_quota_scopes," + string(scope) + "}"
scopeKey := string(scope)
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
`UPDATE accounts SET
extra = jsonb_set(
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
ARRAY['antigravity_quota_scopes', $1]::text[],
$2::jsonb,
true
),
updated_at = NOW(),
last_used_at = NOW()
WHERE id = $3 AND deleted_at IS NULL`,
scopeKey,
raw,
id,
)
@@ -829,6 +838,7 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
}
@@ -849,12 +859,19 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
return err
}
path := "{model_rate_limits," + scope + "}"
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
`UPDATE accounts SET
extra = jsonb_set(
jsonb_set(COALESCE(extra, '{}'::jsonb), '{model_rate_limits}'::text[], COALESCE(extra->'model_rate_limits', '{}'::jsonb), true),
ARRAY['model_rate_limits', $1]::text[],
$2::jsonb,
true
),
updated_at = NOW()
WHERE id = $3 AND deleted_at IS NULL`,
scope,
raw,
id,
)

View File

@@ -0,0 +1,95 @@
package repository
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// AESEncryptor implements SecretEncryptor using AES-256-GCM
type AESEncryptor struct {
key []byte
}
// NewAESEncryptor creates a new AES encryptor
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
}
return &AESEncryptor{key: key}, nil
}
// Encrypt encrypts plaintext using AES-256-GCM
// Output format: base64(nonce + ciphertext + tag)
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
// Generate a random nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("generate nonce: %w", err)
}
// Encrypt the plaintext
// Seal appends the ciphertext and tag to the nonce
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode as base64
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts ciphertext using AES-256-GCM
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
// Decode from base64
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", fmt.Errorf("decode base64: %w", err)
}
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
// Extract nonce and ciphertext
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
// Decrypt
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}

View File

@@ -0,0 +1,83 @@
package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type announcementReadRepository struct {
client *dbent.Client
}
func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
return &announcementReadRepository{client: client}
}
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
client := clientFromContext(ctx, r.client)
return client.AnnouncementRead.Create().
SetAnnouncementID(announcementID).
SetUserID(userID).
SetReadAt(readAt).
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
DoNothing().
Exec(ctx)
}
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
if len(announcementIDs) == 0 {
return map[int64]time.Time{}, nil
}
rows, err := r.client.AnnouncementRead.Query().
Where(
announcementread.UserIDEQ(userID),
announcementread.AnnouncementIDIn(announcementIDs...),
).
All(ctx)
if err != nil {
return nil, err
}
out := make(map[int64]time.Time, len(rows))
for i := range rows {
out[rows[i].AnnouncementID] = rows[i].ReadAt
}
return out, nil
}
func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
if len(userIDs) == 0 {
return map[int64]time.Time{}, nil
}
rows, err := r.client.AnnouncementRead.Query().
Where(
announcementread.AnnouncementIDEQ(announcementID),
announcementread.UserIDIn(userIDs...),
).
All(ctx)
if err != nil {
return nil, err
}
out := make(map[int64]time.Time, len(rows))
for i := range rows {
out[rows[i].UserID] = rows[i].ReadAt
}
return out, nil
}
func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
count, err := r.client.AnnouncementRead.Query().
Where(announcementread.AnnouncementIDEQ(announcementID)).
Count(ctx)
if err != nil {
return 0, err
}
return int64(count), nil
}

View File

@@ -0,0 +1,194 @@
package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type announcementRepository struct {
client *dbent.Client
}
func NewAnnouncementRepository(client *dbent.Client) service.AnnouncementRepository {
return &announcementRepository{client: client}
}
func (r *announcementRepository) Create(ctx context.Context, a *service.Announcement) error {
client := clientFromContext(ctx, r.client)
builder := client.Announcement.Create().
SetTitle(a.Title).
SetContent(a.Content).
SetStatus(a.Status).
SetTargeting(a.Targeting)
if a.StartsAt != nil {
builder.SetStartsAt(*a.StartsAt)
}
if a.EndsAt != nil {
builder.SetEndsAt(*a.EndsAt)
}
if a.CreatedBy != nil {
builder.SetCreatedBy(*a.CreatedBy)
}
if a.UpdatedBy != nil {
builder.SetUpdatedBy(*a.UpdatedBy)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
applyAnnouncementEntityToService(a, created)
return nil
}
func (r *announcementRepository) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
m, err := r.client.Announcement.Query().
Where(announcement.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
}
return announcementEntityToService(m), nil
}
func (r *announcementRepository) Update(ctx context.Context, a *service.Announcement) error {
client := clientFromContext(ctx, r.client)
builder := client.Announcement.UpdateOneID(a.ID).
SetTitle(a.Title).
SetContent(a.Content).
SetStatus(a.Status).
SetTargeting(a.Targeting)
if a.StartsAt != nil {
builder.SetStartsAt(*a.StartsAt)
} else {
builder.ClearStartsAt()
}
if a.EndsAt != nil {
builder.SetEndsAt(*a.EndsAt)
} else {
builder.ClearEndsAt()
}
if a.CreatedBy != nil {
builder.SetCreatedBy(*a.CreatedBy)
} else {
builder.ClearCreatedBy()
}
if a.UpdatedBy != nil {
builder.SetUpdatedBy(*a.UpdatedBy)
} else {
builder.ClearUpdatedBy()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
}
a.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *announcementRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.Announcement.Delete().Where(announcement.IDEQ(id)).Exec(ctx)
return err
}
func (r *announcementRepository) List(
ctx context.Context,
params pagination.PaginationParams,
filters service.AnnouncementListFilters,
) ([]service.Announcement, *pagination.PaginationResult, error) {
q := r.client.Announcement.Query()
if filters.Status != "" {
q = q.Where(announcement.StatusEQ(filters.Status))
}
if filters.Search != "" {
q = q.Where(
announcement.Or(
announcement.TitleContainsFold(filters.Search),
announcement.ContentContainsFold(filters.Search),
),
)
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
items, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(announcement.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
out := announcementEntitiesToService(items)
return out, paginationResultFromTotal(int64(total), params), nil
}
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
q := r.client.Announcement.Query().
Where(
announcement.StatusEQ(service.AnnouncementStatusActive),
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
).
Order(dbent.Desc(announcement.FieldID))
items, err := q.All(ctx)
if err != nil {
return nil, err
}
return announcementEntitiesToService(items), nil
}
func applyAnnouncementEntityToService(dst *service.Announcement, src *dbent.Announcement) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
if m == nil {
return nil
}
return &service.Announcement{
ID: m.ID,
Title: m.Title,
Content: m.Content,
Status: m.Status,
Targeting: m.Targeting,
StartsAt: m.StartsAt,
EndsAt: m.EndsAt,
CreatedBy: m.CreatedBy,
UpdatedBy: m.UpdatedBy,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func announcementEntitiesToService(models []*dbent.Announcement) []service.Announcement {
out := make([]service.Announcement, 0, len(models))
for i := range models {
if s := announcementEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}

View File

@@ -391,17 +391,20 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}

View File

@@ -9,13 +9,27 @@ import (
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
const (
verifyCodeKeyPrefix = "verify_code:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
)
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string {
return passwordResetSentAtKeyPrefix + email
}
type emailCache struct {
rdb *redis.Client
}
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err()
}
// Password reset token methods
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
key := passwordResetKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data service.PasswordResetTokenData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
key := passwordResetKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
key := passwordResetKey(email)
return c.rdb.Del(ctx, key).Err()
}
// Password reset email cooldown methods
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
key := passwordResetSentAtKey(email)
exists, err := c.rdb.Exists(ctx, key).Result()
return err == nil && exists > 0
}
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err()
}

View File

@@ -433,3 +433,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
return counts, nil
}
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID去重
func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
if len(groupIDs) == 0 {
return nil, nil
}
rows, err := r.sql.QueryContext(
ctx,
"SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
pq.Array(groupIDs),
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var accountIDs []int64
for rows.Next() {
var accountID int64
if err := rows.Scan(&accountID); err != nil {
return nil, err
}
accountIDs = append(accountIDs, accountID)
}
if err := rows.Err(); err != nil {
return nil, err
}
return accountIDs, nil
}
// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定)
func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
if len(accountIDs) == 0 {
return nil
}
// 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定
_, err := r.sql.ExecContext(
ctx,
`INSERT INTO account_groups (account_id, group_id, priority, created_at)
SELECT unnest($1::bigint[]), $2, 50, NOW()
ON CONFLICT (account_id, group_id) DO NOTHING`,
pq.Array(accountIDs),
groupID,
)
if err != nil {
return err
}
// 发送调度器事件
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
}
return nil
}

View File

@@ -2,11 +2,11 @@ package repository
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
@@ -22,7 +22,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(s.tokenURL, proxyURL)
client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -39,23 +39,24 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(s.tokenURL, proxyURL)
client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -67,29 +68,25 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
forceHTTP2 := false
if parsedURL, err := url.Parse(tokenURL); err == nil {
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
}
func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
ForceHTTP2: forceHTTP2,
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
})
}

View File

@@ -28,7 +28,6 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
}
return &proxyProbeService{
ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP,
@@ -36,12 +35,20 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
}
const (
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
defaultProxyProbeTimeout = 30 * time.Second
)
// probeURLs 按优先级排列的探测 URL 列表
// 某些 AI API 专用代理只允许访问特定域名,因此需要多个备选
var probeURLs = []struct {
url string
parser string // "ip-api" or "httpbin"
}{
{"http://ip-api.com/json/?lang=zh-CN", "ip-api"},
{"http://httpbin.org/ip", "httpbin"},
}
type proxyProbeService struct {
ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
validateResolvedIP bool
@@ -60,8 +67,21 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
}
var lastErr error
for _, probe := range probeURLs {
exitInfo, latencyMs, err := s.probeWithURL(ctx, client, probe.url, probe.parser)
if err == nil {
return exitInfo, latencyMs, nil
}
lastErr = err
}
return nil, 0, fmt.Errorf("all probe URLs failed, last error: %w", lastErr)
}
func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Client, url string, parser string) (*service.ProxyExitInfo, int64, error) {
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}
@@ -78,6 +98,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
switch parser {
case "ip-api":
return s.parseIPAPI(body, latencyMs)
case "httpbin":
return s.parseHTTPBin(body, latencyMs)
default:
return nil, latencyMs, fmt.Errorf("unknown parser: %s", parser)
}
}
func (s *proxyProbeService) parseIPAPI(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
var ipInfo struct {
Status string `json:"status"`
Message string `json:"message"`
@@ -89,13 +125,12 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
CountryCode string `json:"countryCode"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
preview := string(body)
if len(preview) > 200 {
preview = preview[:200] + "..."
}
return nil, latencyMs, fmt.Errorf("failed to parse response: %w (body: %s)", err, preview)
}
if strings.ToLower(ipInfo.Status) != "success" {
if ipInfo.Message == "" {
@@ -116,3 +151,19 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
CountryCode: ipInfo.CountryCode,
}, latencyMs, nil
}
func (s *proxyProbeService) parseHTTPBin(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
// httpbin.org/ip 返回格式: {"origin": "1.2.3.4"}
var result struct {
Origin string `json:"origin"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse httpbin response: %w", err)
}
if result.Origin == "" {
return nil, latencyMs, fmt.Errorf("httpbin: no IP found in response")
}
return &service.ProxyExitInfo{
IP: result.Origin,
}, latencyMs, nil
}

View File

@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/require"
@@ -21,7 +22,6 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
allowPrivateHosts: true,
}
}
@@ -49,12 +49,16 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
require.ErrorContains(s.T(), err, "failed to create proxy client")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
seen := make(chan string, 1)
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_IPAPI() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
// 检查是否是 ip-api 请求
if strings.Contains(r.RequestURI, "ip-api.com") {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
return
}
// 其他请求返回错误
w.WriteHeader(http.StatusServiceUnavailable)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
@@ -65,45 +69,59 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
require.Equal(s.T(), "CC", info.CountryCode)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_HTTPBinFallback() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ip-api 失败
if strings.Contains(r.RequestURI, "ip-api.com") {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
// httpbin 成功
if strings.Contains(r.RequestURI, "httpbin.org") {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"origin": "5.6.7.8"}`)
return
}
w.WriteHeader(http.StatusServiceUnavailable)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.NoError(s.T(), err, "ProbeProxy should fallback to httpbin")
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
require.Equal(s.T(), "5.6.7.8", info.IP)
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_AllFailed() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status: 503")
require.ErrorContains(s.T(), err, "all probe URLs failed")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
if strings.Contains(r.RequestURI, "ip-api.com") {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
return
}
// httpbin 也返回无效响应
if strings.Contains(r.RequestURI, "httpbin.org") {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
return
}
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "failed to parse response")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
s.prober.ipInfoURL = "://invalid-url"
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
require.ErrorContains(s.T(), err, "all probe URLs failed")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
@@ -114,6 +132,40 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
require.Error(s.T(), err, "expected error when proxy server is closed")
}
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Success() {
body := []byte(`{"status":"success","query":"1.2.3.4","city":"Beijing","regionName":"Beijing","country":"China","countryCode":"CN"}`)
info, latencyMs, err := s.prober.parseIPAPI(body, 100)
require.NoError(s.T(), err)
require.Equal(s.T(), int64(100), latencyMs)
require.Equal(s.T(), "1.2.3.4", info.IP)
require.Equal(s.T(), "Beijing", info.City)
require.Equal(s.T(), "Beijing", info.Region)
require.Equal(s.T(), "China", info.Country)
require.Equal(s.T(), "CN", info.CountryCode)
}
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Failure() {
body := []byte(`{"status":"fail","message":"rate limited"}`)
_, _, err := s.prober.parseIPAPI(body, 100)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "rate limited")
}
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_Success() {
body := []byte(`{"origin": "9.8.7.6"}`)
info, latencyMs, err := s.prober.parseHTTPBin(body, 50)
require.NoError(s.T(), err)
require.Equal(s.T(), int64(50), latencyMs)
require.Equal(s.T(), "9.8.7.6", info.IP)
}
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_NoIP() {
body := []byte(`{"origin": ""}`)
_, _, err := s.prober.parseHTTPBin(body, 50)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "no IP found")
}
func TestProxyProbeServiceSuite(t *testing.T) {
suite.Run(t, new(ProxyProbeServiceSuite))
}

View File

@@ -202,6 +202,57 @@ func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, lim
return redeemCodeEntitiesToService(codes), nil
}
// ListByUserPaginated returns paginated balance/concurrency history for a user.
// Supports optional type filter (e.g. "balance", "admin_balance", "concurrency", "admin_concurrency", "subscription").
func (r *redeemCodeRepository) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
q := r.client.RedeemCode.Query().
Where(redeemcode.UsedByEQ(userID))
// Optional type filter
if codeType != "" {
q = q.Where(redeemcode.TypeEQ(codeType))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
codes, err := q.
WithGroup().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(redeemcode.FieldUsedAt)).
All(ctx)
if err != nil {
return nil, nil, err
}
return redeemCodeEntitiesToService(codes), paginationResultFromTotal(int64(total), params), nil
}
// SumPositiveBalanceByUser returns total recharged amount (sum of value > 0 where type is balance/admin_balance).
func (r *redeemCodeRepository) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
var result []struct {
Sum float64 `json:"sum"`
}
err := r.client.RedeemCode.Query().
Where(
redeemcode.UsedByEQ(userID),
redeemcode.ValueGT(0),
redeemcode.TypeIn("balance", "admin_balance"),
).
Aggregate(dbent.As(dbent.Sum(redeemcode.FieldValue), "sum")).
Scan(ctx, &result)
if err != nil {
return 0, err
}
if len(result) == 0 {
return 0, nil
}
return result[0].Sum, nil
}
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
if m == nil {
return nil

View File

@@ -1,6 +1,7 @@
package repository
import (
"crypto/tls"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -26,7 +27,7 @@ func InitRedis(cfg *config.Config) *redis.Client {
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func buildRedisOptions(cfg *config.Config) *redis.Options {
return &redis.Options{
opts := &redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
@@ -36,4 +37,13 @@ func buildRedisOptions(cfg *config.Config) *redis.Options {
PoolSize: cfg.Redis.PoolSize, // 连接池大小
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
}
if cfg.Redis.EnableTLS {
opts.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: cfg.Redis.Host,
}
}
return opts
}

View File

@@ -32,4 +32,16 @@ func TestBuildRedisOptions(t *testing.T) {
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
require.Nil(t, opts.TLSConfig)
// Test case with TLS enabled
cfgTLS := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
EnableTLS: true,
},
}
optsTLS := buildRedisOptions(cfgTLS)
require.NotNil(t, optsTLS.TLSConfig)
require.Equal(t, "localhost", optsTLS.TLSConfig.ServerName)
}

View File

@@ -77,21 +77,9 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, "2", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
client := createOpenAIReqClient("http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}

View File

@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
return nil, false, err
}
if len(ids) == 0 {
return []*service.Account{}, true, nil
// 空快照视为缓存未命中,触发数据库回退查询
// 这解决了新分组创建后立即绑定账号时的竞态条件问题
return nil, false, nil
}
keys := make([]string, 0, len(ids))

View File

@@ -0,0 +1,149 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
totpSetupKeyPrefix = "totp:setup:"
totpLoginKeyPrefix = "totp:login:"
totpAttemptsKeyPrefix = "totp:attempts:"
totpAttemptsTTL = 15 * time.Minute
)
// TotpCache implements service.TotpCache using Redis
type TotpCache struct {
rdb *redis.Client
}
// NewTotpCache creates a new TOTP cache
func NewTotpCache(rdb *redis.Client) service.TotpCache {
return &TotpCache{rdb: rdb}
}
// GetSetupSession retrieves a TOTP setup session
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get setup session: %w", err)
}
var session service.TotpSetupSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal setup session: %w", err)
}
return &session, nil
}
// SetSetupSession stores a TOTP setup session
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal setup session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set setup session: %w", err)
}
return nil
}
// DeleteSetupSession deletes a TOTP setup session
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
// GetLoginSession retrieves a TOTP login session
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
key := totpLoginKeyPrefix + tempToken
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get login session: %w", err)
}
var session service.TotpLoginSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal login session: %w", err)
}
return &session, nil
}
// SetLoginSession stores a TOTP login session
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
key := totpLoginKeyPrefix + tempToken
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal login session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set login session: %w", err)
}
return nil
}
// DeleteLoginSession deletes a TOTP login session
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
key := totpLoginKeyPrefix + tempToken
return c.rdb.Del(ctx, key).Err()
}
// IncrementVerifyAttempts increments the verify attempt counter
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
// Use pipeline for atomic increment and set TTL
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, totpAttemptsTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("increment verify attempts: %w", err)
}
count, err := incrCmd.Result()
if err != nil {
return 0, fmt.Errorf("get increment result: %w", err)
}
return int(count), nil
}
// GetVerifyAttempts gets the current verify attempt count
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
count, err := c.rdb.Get(ctx, key).Int()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, fmt.Errorf("get verify attempts: %w", err)
}
return count, nil
}
// ClearVerifyAttempts clears the verify attempt counter
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at"
type usageLogRepository struct {
client *dbent.Client
@@ -115,6 +115,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
image_count,
image_size,
media_type,
reasoning_effort,
created_at
) VALUES (
$1, $2, $3, $4, $5,
@@ -122,7 +123,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -136,6 +137,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
mediaType := nullString(log.MediaType)
reasoningEffort := nullString(log.ReasoningEffort)
var requestIDArg any
if requestID != "" {
@@ -173,6 +175,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.ImageCount,
imageSize,
mediaType,
reasoningEffort,
createdAt,
}
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
@@ -2094,6 +2097,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
imageCount int
imageSize sql.NullString
mediaType sql.NullString
reasoningEffort sql.NullString
createdAt time.Time
)
@@ -2129,6 +2133,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&imageCount,
&imageSize,
&mediaType,
&reasoningEffort,
&createdAt,
); err != nil {
return nil, err
@@ -2191,6 +2196,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if mediaType.Valid {
log.MediaType = &mediaType.String
}
if reasoningEffort.Valid {
log.ReasoningEffort = &reasoningEffort.String
}
return log, nil
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -189,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
dbuser.Or(
dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(filters.Search),
dbuser.NotesContainsFold(filters.Search),
),
)
}
@@ -466,3 +468,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)
update := client.User.UpdateOneID(userID)
if encryptedSecret == nil {
update = update.ClearTotpSecretEncrypted()
} else {
update = update.SetTotpSecretEncrypted(*encryptedSecret)
}
_, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(true).
SetTotpEnabledAt(time.Now()).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(false).
ClearTotpEnabledAt().
ClearTotpSecretEncrypted().
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}

View File

@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
if status != "" {
// Status filtering with real-time expiration check
now := time.Now()
switch status {
case service.SubscriptionStatusActive:
// Active: status is active AND not yet expired
q = q.Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(now),
)
case service.SubscriptionStatusExpired:
// Expired: status is expired OR (status is active but already expired)
q = q.Where(
usersubscription.Or(
usersubscription.StatusEQ(service.SubscriptionStatusExpired),
usersubscription.And(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(now),
),
),
)
case "":
// No filter
default:
// Other status (e.g., revoked)
q = q.Where(usersubscription.StatusEQ(status))
}
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err
}
// Apply sorting
q = q.WithUser().WithGroup().WithAssignedByUser()
// Determine sort field
var field string
switch sortBy {
case "expires_at":
field = usersubscription.FieldExpiresAt
case "status":
field = usersubscription.FieldStatus
default:
field = usersubscription.FieldCreatedAt
}
// Determine sort order (default: desc)
if sortOrder == "asc" && sortBy != "" {
q = q.Order(dbent.Asc(field))
} else {
q = q.Order(dbent.Desc(field))
}
subs, err := q.
WithUser().
WithGroup().
WithAssignedByUser().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)

View File

@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)

View File

@@ -57,6 +57,8 @@ var ProviderSet = wire.NewSet(
NewProxyRepository,
NewRedeemCodeRepository,
NewPromoCodeRepository,
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,
@@ -83,6 +85,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
NewTotpCache,
// Encryptors
NewAESEncryptor,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,

View File

@@ -201,7 +201,7 @@ func TestAPIContracts(t *testing.T) {
UserID: 1,
GroupID: 10,
StartsAt: deps.now,
ExpiresAt: deps.now.Add(24 * time.Hour),
ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status: service.SubscriptionStatusActive,
DailyUsageUSD: 1.23,
WeeklyUsageUSD: 2.34,
@@ -226,7 +226,7 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "2025-01-03T03:04:05Z",
"expires_at": "2099-01-02T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
@@ -457,6 +457,9 @@ func TestAPIContracts(t *testing.T) {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
@@ -490,8 +493,11 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"invitation_code_enabled": false,
"home_content": "",
"hide_ccs_import_button": false
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": ""
}
}`,
},
@@ -600,7 +606,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
@@ -759,6 +765,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
type stubApiKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
@@ -868,6 +886,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
return 0, errors.New("not implemented")
}
func (stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return errors.New("not implemented")
}
func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
type stubAccountRepo struct {
bulkUpdateIDs []int64
}
@@ -1133,6 +1159,14 @@ func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit
return append([]service.RedeemCode(nil), codes...), nil
}
func (stubRedeemCodeRepo) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubRedeemCodeRepo) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct {
byUser map[int64][]service.UserSubscription
activeByUser map[int64][]service.UserSubscription
@@ -1185,7 +1219,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {

View File

@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}

View File

@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
// 账号管理
registerAccountRoutes(admin, h)
// 公告管理
registerAnnouncementRoutes(admin, h)
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
@@ -172,6 +175,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
@@ -229,6 +233,18 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerAnnouncementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
announcements := admin.Group("/announcements")
{
announcements.GET("", h.Admin.Announcement.List)
announcements.POST("", h.Admin.Announcement.Create)
announcements.GET("/:id", h.Admin.Announcement.GetByID)
announcements.PUT("/:id", h.Admin.Announcement.Update)
announcements.DELETE("/:id", h.Admin.Announcement.Delete)
announcements.GET("/:id/read-status", h.Admin.Announcement.ListReadStatus)
}
}
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
openai := admin.Group("/openai")
{

View File

@@ -26,11 +26,24 @@ func RegisterAuthRoutes(
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode)
// 邀请码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-invitation-code", rateLimiter.LimitWithOptions("validate-invitation", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidateInvitationCode)
// 忘记密码接口添加速率限制:每分钟最多 5 次Redis 故障时 fail-close
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ForgotPassword)
// 重置密码接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
}

View File

@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
// TOTP 双因素认证
totp := user.Group("/totp")
{
totp.GET("/status", h.Totp.GetStatus)
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
totp.POST("/send-code", h.Totp.SendVerifyCode)
totp.POST("/setup", h.Totp.InitiateSetup)
totp.POST("/enable", h.Totp.Enable)
totp.POST("/disable", h.Totp.Disable)
}
}
// API Key管理
@@ -53,6 +64,13 @@ func RegisterUserRoutes(
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
}
// 公告(用户可见)
announcements := authenticated.Group("/announcements")
{
announcements.GET("", h.Announcement.List)
announcements.POST("/:id/read", h.Announcement.MarkRead)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{

View File

@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
return ""
}
func (a *Account) GetClaudeUserID() string {
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
return v
}
return ""
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false

View File

@@ -124,7 +124,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system": []map[string]any{
{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{
"type": "ephemeral",
},

View File

@@ -22,6 +22,10 @@ type AdminService interface {
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
@@ -115,6 +119,8 @@ type CreateGroupInput struct {
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
type UpdateGroupInput struct {
@@ -142,6 +148,8 @@ type UpdateGroupInput struct {
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
type CreateAccountInput struct {
@@ -535,6 +543,21 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
}, nil
}
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
if err != nil {
return nil, 0, 0, err
}
// Aggregate total recharged amount (only once, regardless of type filter)
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
if err != nil {
return nil, 0, 0, err
}
return codes, result.Total, totalRecharged, nil
}
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
@@ -589,6 +612,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
}
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var accountIDsToCopy []int64
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
seen := make(map[int64]struct{})
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
if _, exists := seen[srcGroupID]; !exists {
seen[srcGroupID] = struct{}{}
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
}
}
// 校验源分组的平台是否与新分组一致
for _, srcGroupID := range uniqueSourceGroupIDs {
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
if err != nil {
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
}
if srcGroup.Platform != platform {
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
}
}
// 获取所有源分组的账号(去重)
var err error
accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
if err != nil {
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
}
}
group := &Group{
Name: input.Name,
Description: input.Description,
@@ -614,6 +669,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
// 如果有需要复制的账号,绑定到新分组
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
}
group.AccountCount = int64(len(accountIDsToCopy))
}
return group, nil
}
@@ -761,6 +825,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
seen := make(map[int64]struct{})
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
// 校验:源分组不能是自身
if srcGroupID == id {
return nil, fmt.Errorf("cannot copy accounts from self")
}
// 去重
if _, exists := seen[srcGroupID]; !exists {
seen[srcGroupID] = struct{}{}
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
}
}
// 校验源分组的平台是否与当前分组一致
for _, srcGroupID := range uniqueSourceGroupIDs {
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
if err != nil {
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
}
if srcGroup.Platform != group.Platform {
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
}
}
// 获取所有源分组的账号(去重)
accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
if err != nil {
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
}
// 先清空当前分组的所有账号绑定
if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
}
// 再绑定源分组的账号
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
}
}
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}

View File

@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
type groupRepoStub struct {
affectedUserIDs []int64
deleteErr error
@@ -152,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
type proxyRepoStub struct {
deleteErr error
countErr error
@@ -262,6 +282,14 @@ func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int
panic("unexpected ListByUser call")
}
func (s *redeemRepoStub) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListByUserPaginated call")
}
func (s *redeemRepoStub) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
panic("unexpected SumPositiveBalanceByUser call")
}
type subscriptionInvalidateCall struct {
userID int64
groupID int64

View File

@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
@@ -378,3 +386,11 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}

View File

@@ -152,6 +152,14 @@ func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params p
return s.listWithFiltersCodes, result, nil
}
func (s *redeemRepoStubForAdminList) ListByUserPaginated(_ context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListByUserPaginated call")
}
func (s *redeemRepoStubForAdminList) SumPositiveBalanceByUser(_ context.Context, userID int64) (float64, error) {
panic("unexpected SumPositiveBalanceByUser call")
}
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &accountRepoStubForAdminList{

View File

@@ -0,0 +1,64 @@
package service
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
AnnouncementStatusDraft = domain.AnnouncementStatusDraft
AnnouncementStatusActive = domain.AnnouncementStatusActive
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
)
const (
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
)
const (
AnnouncementOperatorIn = domain.AnnouncementOperatorIn
AnnouncementOperatorGT = domain.AnnouncementOperatorGT
AnnouncementOperatorGTE = domain.AnnouncementOperatorGTE
AnnouncementOperatorLT = domain.AnnouncementOperatorLT
AnnouncementOperatorLTE = domain.AnnouncementOperatorLTE
AnnouncementOperatorEQ = domain.AnnouncementOperatorEQ
)
var (
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
)
type AnnouncementTargeting = domain.AnnouncementTargeting
type AnnouncementConditionGroup = domain.AnnouncementConditionGroup
type AnnouncementCondition = domain.AnnouncementCondition
type Announcement = domain.Announcement
type AnnouncementListFilters struct {
Status string
Search string
}
type AnnouncementRepository interface {
Create(ctx context.Context, a *Announcement) error
GetByID(ctx context.Context, id int64) (*Announcement, error)
Update(ctx context.Context, a *Announcement) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error)
ListActive(ctx context.Context, now time.Time) ([]Announcement, error)
}
type AnnouncementReadRepository interface {
MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error
GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error)
GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error)
CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error)
}

View File

@@ -0,0 +1,378 @@
package service
import (
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type AnnouncementService struct {
announcementRepo AnnouncementRepository
readRepo AnnouncementReadRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
}
func NewAnnouncementService(
announcementRepo AnnouncementRepository,
readRepo AnnouncementReadRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
) *AnnouncementService {
return &AnnouncementService{
announcementRepo: announcementRepo,
readRepo: readRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
}
}
type CreateAnnouncementInput struct {
Title string
Content string
Status string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
ActorID *int64 // 管理员用户ID
}
type UpdateAnnouncementInput struct {
Title *string
Content *string
Status *string
Targeting *AnnouncementTargeting
StartsAt **time.Time
EndsAt **time.Time
ActorID *int64 // 管理员用户ID
}
type UserAnnouncement struct {
Announcement Announcement
ReadAt *time.Time
}
type AnnouncementUserReadStatus struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Balance float64 `json:"balance"`
Eligible bool `json:"eligible"`
ReadAt *time.Time `json:"read_at,omitempty"`
}
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("create announcement: nil input")
}
title := strings.TrimSpace(input.Title)
content := strings.TrimSpace(input.Content)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("create announcement: invalid title")
}
if content == "" {
return nil, fmt.Errorf("create announcement: content is required")
}
status := strings.TrimSpace(input.Status)
if status == "" {
status = AnnouncementStatusDraft
}
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("create announcement: invalid status")
}
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
if err != nil {
return nil, err
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
}
}
a := &Announcement{
Title: title,
Content: content,
Status: status,
Targeting: targeting,
StartsAt: input.StartsAt,
EndsAt: input.EndsAt,
}
if input.ActorID != nil && *input.ActorID > 0 {
a.CreatedBy = input.ActorID
a.UpdatedBy = input.ActorID
}
if err := s.announcementRepo.Create(ctx, a); err != nil {
return nil, fmt.Errorf("create announcement: %w", err)
}
return a, nil
}
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("update announcement: nil input")
}
a, err := s.announcementRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Title != nil {
title := strings.TrimSpace(*input.Title)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("update announcement: invalid title")
}
a.Title = title
}
if input.Content != nil {
content := strings.TrimSpace(*input.Content)
if content == "" {
return nil, fmt.Errorf("update announcement: content is required")
}
a.Content = content
}
if input.Status != nil {
status := strings.TrimSpace(*input.Status)
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("update announcement: invalid status")
}
a.Status = status
}
if input.Targeting != nil {
targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
if err != nil {
return nil, err
}
a.Targeting = targeting
}
if input.StartsAt != nil {
a.StartsAt = *input.StartsAt
}
if input.EndsAt != nil {
a.EndsAt = *input.EndsAt
}
if a.StartsAt != nil && a.EndsAt != nil {
if !a.StartsAt.Before(*a.EndsAt) {
return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
}
}
if input.ActorID != nil && *input.ActorID > 0 {
a.UpdatedBy = input.ActorID
}
if err := s.announcementRepo.Update(ctx, a); err != nil {
return nil, fmt.Errorf("update announcement: %w", err)
}
return a, nil
}
func (s *AnnouncementService) Delete(ctx context.Context, id int64) error {
if err := s.announcementRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete announcement: %w", err)
}
return nil
}
func (s *AnnouncementService) GetByID(ctx context.Context, id int64) (*Announcement, error) {
return s.announcementRepo.GetByID(ctx, id)
}
func (s *AnnouncementService) List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
return s.announcementRepo.List(ctx, params, filters)
}
func (s *AnnouncementService) ListForUser(ctx context.Context, userID int64, unreadOnly bool) ([]UserAnnouncement, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
for i := range activeSubs {
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
}
now := time.Now()
anns, err := s.announcementRepo.ListActive(ctx, now)
if err != nil {
return nil, fmt.Errorf("list active announcements: %w", err)
}
visible := make([]Announcement, 0, len(anns))
ids := make([]int64, 0, len(anns))
for i := range anns {
a := anns[i]
if !a.IsActiveAt(now) {
continue
}
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
continue
}
visible = append(visible, a)
ids = append(ids, a.ID)
}
if len(visible) == 0 {
return []UserAnnouncement{}, nil
}
readMap, err := s.readRepo.GetReadMapByUser(ctx, userID, ids)
if err != nil {
return nil, fmt.Errorf("get read map: %w", err)
}
out := make([]UserAnnouncement, 0, len(visible))
for i := range visible {
a := visible[i]
readAt, ok := readMap[a.ID]
if unreadOnly && ok {
continue
}
var ptr *time.Time
if ok {
t := readAt
ptr = &t
}
out = append(out, UserAnnouncement{
Announcement: a,
ReadAt: ptr,
})
}
// 未读优先、同状态按创建时间倒序
sort.Slice(out, func(i, j int) bool {
ai, aj := out[i], out[j]
if (ai.ReadAt == nil) != (aj.ReadAt == nil) {
return ai.ReadAt == nil
}
return ai.Announcement.ID > aj.Announcement.ID
})
return out, nil
}
func (s *AnnouncementService) MarkRead(ctx context.Context, userID, announcementID int64) error {
// 安全:仅允许标记当前用户“可见”的公告
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
a, err := s.announcementRepo.GetByID(ctx, announcementID)
if err != nil {
return err
}
now := time.Now()
if !a.IsActiveAt(now) {
return ErrAnnouncementNotFound
}
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
for i := range activeSubs {
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
}
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
return ErrAnnouncementNotFound
}
if err := s.readRepo.MarkRead(ctx, announcementID, userID, now); err != nil {
return fmt.Errorf("mark read: %w", err)
}
return nil
}
func (s *AnnouncementService) ListUserReadStatus(
ctx context.Context,
announcementID int64,
params pagination.PaginationParams,
search string,
) ([]AnnouncementUserReadStatus, *pagination.PaginationResult, error) {
ann, err := s.announcementRepo.GetByID(ctx, announcementID)
if err != nil {
return nil, nil, err
}
filters := UserListFilters{
Search: strings.TrimSpace(search),
}
users, page, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)
}
userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
readMap, err := s.readRepo.GetReadMapByUsers(ctx, announcementID, userIDs)
if err != nil {
return nil, nil, fmt.Errorf("get read map: %w", err)
}
out := make([]AnnouncementUserReadStatus, 0, len(users))
for i := range users {
u := users[i]
subs, err := s.userSubRepo.ListActiveByUserID(ctx, u.ID)
if err != nil {
return nil, nil, fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(subs))
for j := range subs {
activeGroupIDs[subs[j].GroupID] = struct{}{}
}
readAt, ok := readMap[u.ID]
var ptr *time.Time
if ok {
t := readAt
ptr = &t
}
out = append(out, AnnouncementUserReadStatus{
UserID: u.ID,
Email: u.Email,
Username: u.Username,
Balance: u.Balance,
Eligible: domain.AnnouncementTargeting(ann.Targeting).Matches(u.Balance, activeGroupIDs),
ReadAt: ptr,
})
}
return out, page, nil
}
func isValidAnnouncementStatus(status string) bool {
switch status {
case AnnouncementStatusDraft, AnnouncementStatusActive, AnnouncementStatusArchived:
return true
default:
return false
}
}

View File

@@ -0,0 +1,66 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAnnouncementTargeting_Matches_EmptyMatchesAll(t *testing.T) {
var targeting AnnouncementTargeting
require.True(t, targeting.Matches(0, nil))
require.True(t, targeting.Matches(123.45, map[int64]struct{}{1: {}}))
}
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{AllOf: nil},
},
}
_, err := targeting.NormalizeAndValidate()
require.Error(t, err)
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
}
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{
AllOf: []AnnouncementCondition{
{Type: "balance", Operator: "between", Value: 10},
},
},
},
}
_, err := targeting.NormalizeAndValidate()
require.Error(t, err)
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
}
func TestAnnouncementTargeting_Matches_AndOrSemantics(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{
AllOf: []AnnouncementCondition{
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorGTE, Value: 100},
{Type: AnnouncementConditionTypeSubscription, Operator: AnnouncementOperatorIn, GroupIDs: []int64{10}},
},
},
{
AllOf: []AnnouncementCondition{
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorLT, Value: 5},
},
},
},
}
// 命中第 2 组balance < 5
require.True(t, targeting.Matches(4.99, nil))
require.False(t, targeting.Matches(5, nil))
// 命中第 1 组balance >= 100 AND 订阅 in [10]
require.False(t, targeting.Matches(100, map[int64]struct{}{}))
require.False(t, targeting.Matches(99.9, map[int64]struct{}{10: {}}))
require.True(t, targeting.Matches(100, map[int64]struct{}{10: {}}))
}

View File

@@ -273,13 +273,11 @@ func logPrefix(sessionID, accountName string) string {
}
// Antigravity 直接支持的模型(精确匹配透传)
// 注意gemini-2.5 系列已移除,统一映射到 gemini-3 系列
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
"claude-sonnet-4-5-thinking": true,
"gemini-2.5-flash": true,
"gemini-2.5-flash-lite": true,
"gemini-2.5-flash-thinking": true,
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
@@ -288,23 +286,32 @@ var antigravitySupportedModels = map[string]bool{
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
// gemini-2.5 系列统一映射到 gemini-3 系列Antigravity 上游不再支持 2.5
var antigravityPrefixMapping = []struct {
prefix string
target string
}{
// 长前缀优先
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
// gemini-2.5 → gemini-3 映射(长前缀优先
{"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image
{"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash
{"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash
{"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high
{"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high
{"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high
// gemini-3 前缀映射
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
// Claude 映射
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
{"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
@@ -1530,7 +1537,11 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
func antigravityUseScopeRateLimit() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
return v == "1" || v == "true" || v == "yes" || v == "on"
// 默认开启按配额域限流,只有明确设置为禁用值时才关闭
if v == "0" || v == "false" || v == "no" || v == "off" {
return false
}
return true
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {

View File

@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-sonnet-4-5",
},
// 3. Gemini 透传
// 3. Gemini 2.5 → 3 映射
{
name: "Gemini透传 - gemini-2.5-flash",
name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
expected: "gemini-2.5-flash",
expected: "gemini-3-flash",
},
{
name: "Gemini透传 - gemini-2.5-pro",
name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
requestedModel: "gemini-2.5-pro",
accountMapping: nil,
expected: "gemini-2.5-pro",
expected: "gemini-3-pro-high",
},
{
name: "Gemini透传 - gemini-future-model",

View File

@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email
}
// 获取 project_id部分账户类型可能没有
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
result.ProjectID = loadResp.CloudAICompanionProject
// 获取 project_id部分账户类型可能没有,失败时重试
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
result.ProjectIDMissing = true
} else {
result.ProjectID = projectID
}
return result, nil
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo.Email = existingEmail
}
// 每次刷新都调用 LoadCodeAssist 获取 project_id
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
// LoadCodeAssist 失败或返回空,保留原有 project_id标记缺失
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
// LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID
tokenInfo.ProjectIDMissing = true
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id本次只是临时故障不应标记为错误
if existingProjectID == "" {
tokenInfo.ProjectIDMissing = true
}
} else {
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
// 指数退避1s, 2s, 4s
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 8*time.Second {
backoff = 8 * time.Second
}
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil
}
// 记录错误
if err != nil {
lastErr = err
} else if loadResp == nil {
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
} else {
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
}
}
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
// BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{

View File

@@ -89,3 +89,30 @@ func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *tim
}
return &resetAt
}
var antigravityAllScopes = []AntigravityQuotaScope{
AntigravityQuotaScopeClaude,
AntigravityQuotaScopeGeminiText,
AntigravityQuotaScopeGeminiImage,
}
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
if a == nil || a.Platform != PlatformAntigravity {
return nil
}
now := time.Now()
result := make(map[string]int64)
for _, scope := range antigravityAllScopes {
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt != nil && now.Before(*resetAt) {
remainingSec := int64(time.Until(*resetAt).Seconds())
if remainingSec > 0 {
result[string(scope)] = remainingSec
}
}
}
if len(result) == 0 {
return nil
}
return result
}

View File

@@ -3,6 +3,8 @@ package service
import (
"context"
"fmt"
"log"
"strings"
"time"
)
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
}
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
// 合并旧的 credentials保留新 credentials 中不存在的字段
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的access_token 和 refresh_token 已更新)
if tokenInfo.ProjectIDMissing {
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id可能无法使用Antigravity")
if tokenInfo.ProjectID != "" {
// 有旧的 project_id本次获取失败保留旧值
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
} else {
// 从未获取过 project_id本次也失败但不返回错误以允许下次重试
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
}
}
return newCredentials, nil

View File

@@ -19,17 +19,19 @@ import (
)
var (
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
)
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
@@ -47,6 +49,7 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
userRepo UserRepository
redeemRepo RedeemCodeRepository
cfg *config.Config
settingService *SettingService
emailService *EmailService
@@ -58,6 +61,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例
func NewAuthService(
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,
@@ -67,6 +71,7 @@ func NewAuthService(
) *AuthService {
return &AuthService{
userRepo: userRepo,
redeemRepo: redeemRepo,
cfg: cfg,
settingService: settingService,
emailService: emailService,
@@ -78,11 +83,11 @@ func NewAuthService(
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "")
return s.RegisterWithVerification(ctx, email, password, "", "", "")
}
// RegisterWithVerification 用户注册(支持邮件验证优惠码返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
// RegisterWithVerification 用户注册(支持邮件验证优惠码和邀请码返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
// 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrEmailReserved
}
// 检查是否需要邀请码
var invitationRedeemCode *RedeemCode
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
if invitationCode == "" {
return "", nil, ErrInvitationCodeRequired
}
// 验证邀请码
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err)
return "", nil, ErrInvitationCodeInvalid
}
// 检查类型和状态
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
return "", nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
}
// 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
@@ -153,6 +178,14 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
// 邀请码标记失败不影响注册,只记录日志
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
}
}
// 应用优惠码(如果提供且功能已启用)
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
@@ -580,3 +613,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 生成新token
return s.GenerateToken(user)
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证且 SMTP 配置正确
func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool {
if s.settingService == nil {
return false
}
// Must have email verification enabled and SMTP configured
if !s.settingService.IsEmailVerifyEnabled(ctx) {
return false
}
return s.settingService.IsPasswordResetEnabled(ctx)
}
// preparePasswordReset validates the password reset request and returns necessary data
// Returns (siteName, resetURL, shouldProceed)
// shouldProceed is false when we should silently return success (to prevent enumeration)
func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) {
// Check if user exists (but don't reveal this to the caller)
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// Security: Log but don't reveal that user doesn't exist
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
return "", "", false
}
log.Printf("[Auth] Database error checking email for password reset: %v", err)
return "", "", false
}
// Check if user is active
if !user.IsActive() {
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
return "", "", false
}
// Get site name
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
// Build reset URL base
resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/"))
return siteName, resetURL, true
}
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email sent to: %s", email)
return nil
}
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailQueueService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email enqueued for: %s", email)
return nil
}
// ResetPassword 重置密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
// Check if password reset is enabled
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
// Verify and consume the reset token (one-time use)
if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil {
return err
}
// Get user
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return ErrInvalidResetToken // Token was valid but user was deleted
}
log.Printf("[Auth] Database error getting user for password reset: %v", err)
return ErrServiceUnavailable
}
// Check if user is active
if !user.IsActive() {
return ErrUserNotActive
}
// Hash new password
hashedPassword, err := s.HashPassword(newPassword)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
// Update password and increment TokenVersion
user.PasswordHash = hashedPassword
user.TokenVersion++ // Invalidate all existing tokens
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
return ErrServiceUnavailable
}
log.Printf("[Auth] Password reset successful for user: %s", email)
return nil
}

View File

@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil
}
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
return false
}
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
return nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
@@ -95,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return NewAuthService(
repo,
nil, // redeemRepo
cfg,
settingService,
emailService,
@@ -132,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
@@ -144,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
@@ -158,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}

View File

@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
return s.CalculateCost(model, tokens, multiplier)
}
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
// threshold: 阈值(如 200000超过此值的部分按 extraMultiplier 倍计费
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
//
// 示例:缓存 210k + 输入 10k = 220k阈值 200k倍率 2.0
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
// 范围内正常计费,范围外 × 2 计费
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
// 未启用长上下文计费,直接走正常计费
if threshold <= 0 || extraMultiplier <= 1 {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 计算总输入 token缓存读取 + 新输入)
total := tokens.CacheReadTokens + tokens.InputTokens
if total <= threshold {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 拆分成范围内和范围外
var inRangeCacheTokens, inRangeInputTokens int
var outRangeCacheTokens, outRangeInputTokens int
if tokens.CacheReadTokens >= threshold {
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
inRangeCacheTokens = threshold
inRangeInputTokens = 0
outRangeCacheTokens = tokens.CacheReadTokens - threshold
outRangeInputTokens = tokens.InputTokens
} else {
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
inRangeCacheTokens = tokens.CacheReadTokens
inRangeInputTokens = threshold - tokens.CacheReadTokens
outRangeCacheTokens = 0
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
}
// 范围内部分:正常计费
inRangeTokens := UsageTokens{
InputTokens: inRangeInputTokens,
OutputTokens: tokens.OutputTokens, // 输出只算一次
CacheCreationTokens: tokens.CacheCreationTokens,
CacheReadTokens: inRangeCacheTokens,
}
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil {
return nil, err
}
// 范围外部分:× extraMultiplier 计费
outRangeTokens := UsageTokens{
InputTokens: outRangeInputTokens,
CacheReadTokens: outRangeCacheTokens,
}
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
if err != nil {
return inRangeCost, nil // 出错时返回范围内成本
}
// 合并成本
return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
}, nil
}
// ListSupportedModels 列出所有支持的模型现在总是返回true因为有模糊匹配
func (s *BillingService) ListSupportedModels() []string {
models := make([]string, 0)

View File

@@ -1,67 +1,70 @@
package service
import "github.com/Wei-Shaw/sub2api/internal/domain"
// Status constants
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
StatusActive = domain.StatusActive
StatusDisabled = domain.StatusDisabled
StatusError = domain.StatusError
StatusUnused = domain.StatusUnused
StatusUsed = domain.StatusUsed
StatusExpired = domain.StatusExpired
)
// Role constants
const (
RoleAdmin = "admin"
RoleUser = "user"
RoleAdmin = domain.RoleAdmin
RoleUser = domain.RoleUser
)
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformSora = "sora"
PlatformAnthropic = domain.PlatformAnthropic
PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity
PlatformSora = domain.PlatformSora
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
)
// Redeem type constants
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
RedeemTypeBalance = domain.RedeemTypeBalance
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
RedeemTypeSubscription = domain.RedeemTypeSubscription
RedeemTypeInvitation = domain.RedeemTypeInvitation
)
// PromoCode status constants
const (
PromoCodeStatusActive = "active"
PromoCodeStatusDisabled = "disabled"
PromoCodeStatusActive = domain.PromoCodeStatusActive
PromoCodeStatusDisabled = domain.PromoCodeStatusDisabled
)
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
AdjustmentTypeAdminBalance = domain.AdjustmentTypeAdminBalance // 管理员调整余额
AdjustmentTypeAdminConcurrency = domain.AdjustmentTypeAdminConcurrency // 管理员调整并发数
)
// Group subscription type constants
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
SubscriptionTypeStandard = domain.SubscriptionTypeStandard // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = domain.SubscriptionTypeSubscription // 订阅模式(按限额控制)
)
// Subscription status constants
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
SubscriptionStatusActive = domain.SubscriptionStatusActive
SubscriptionStatusExpired = domain.SubscriptionStatusExpired
SubscriptionStatusSuspended = domain.SubscriptionStatusSuspended
)
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀RFC 保留域名)。
@@ -70,9 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
@@ -88,6 +93,9 @@ const (
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// TOTP 双因素认证设置
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
@@ -95,14 +103,16 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML或 URL 作为 iframe src
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML或 URL 作为 iframe src
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL作为 iframe src
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量

View File

@@ -8,11 +8,18 @@ import (
"time"
)
// Task type constants
const (
TaskTypeVerifyCode = "verify_code"
TaskTypePasswordReset = "password_reset"
)
// EmailTask 邮件发送任务
type EmailTask struct {
Email string
SiteName string
TaskType string // "verify_code"
TaskType string // "verify_code" or "password_reset"
ResetURL string // Only used for password_reset task type
}
// EmailQueueService 异步邮件队列服务
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
defer cancel()
switch task.TaskType {
case "verify_code":
case TaskTypeVerifyCode:
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
}
case TaskTypePasswordReset:
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
}
default:
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
}
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: "verify_code",
TaskType: TaskTypeVerifyCode,
}
select {
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
}
}
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: TaskTypePasswordReset,
ResetURL: resetURL,
}
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
}
}
// Stop 停止队列服务
func (s *EmailQueueService) Stop() {
close(s.stopChan)

View File

@@ -3,11 +3,14 @@ package service
import (
"context"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"math/big"
"net/smtp"
"net/url"
"strconv"
"time"
@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
// Password reset errors
ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token")
)
// EmailCache defines cache operations for email service
@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
// Password reset token methods
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
DeletePasswordResetToken(ctx context.Context, email string) error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
}
// VerificationCodeData represents verification code data
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt time.Time
}
// PasswordResetTokenData represents password reset token data
type PasswordResetTokenData struct {
Token string
CreatedAt time.Time
}
const (
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5
// Password reset token settings
passwordResetTokenTTL = 30 * time.Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown = 30 * time.Second
)
// SMTPConfig SMTP配置
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if data.Code != code {
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return client.Quit()
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func (s *EmailService) GeneratePasswordResetToken() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
var token string
var needSaveToken bool
// Check if token already exists
existing, err := s.cache.GetPasswordResetToken(ctx, email)
if err == nil && existing != nil {
// Token exists, reuse it (allows resending email without generating new token)
token = existing.Token
needSaveToken = false
} else {
// Generate new token
token, err = s.GeneratePasswordResetToken()
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
needSaveToken = true
}
// Save token to Redis (only if new token generated)
if needSaveToken {
data := &PasswordResetTokenData{
Token: token,
CreatedAt: time.Now(),
}
if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil {
return fmt.Errorf("save reset token: %w", err)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
// Build email content
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
// Send email
if err := s.SendEmail(ctx, email, subject, body); err != nil {
return fmt.Errorf("send email: %w", err)
}
return nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
return nil // Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
return err
}
// Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
}
return nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error {
data, err := s.cache.GetPasswordResetToken(ctx, email)
if err != nil || data == nil {
return ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 {
return ErrInvalidResetToken
}
return nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error {
// Verify first
if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil {
return err
}
// Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
}
return nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string {
return fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`, siteName, resetURL, resetURL)
}

View File

@@ -0,0 +1,23 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMergeAnthropicBeta(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"foo, oauth-2025-04-20,bar, foo",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
}
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
}

View File

@@ -269,6 +269,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
return 0, nil
}
func (m *mockGroupRepoForGateway) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
func ptr[T any](v T) *T {
return &v
}

View File

@@ -0,0 +1,62 @@
package service
import (
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
System: nil,
Messages: nil,
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
}
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
require.NotEmpty(t, got)
// Legacy format: user_{client}_account__session_{uuid}
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{
"account_uuid": "acc-uuid",
"claude_user_id": "clientid123",
"anthropic_user_id": "",
},
}
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
require.NotEmpty(t, got)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}

View File

@@ -2,6 +2,7 @@ package service
import (
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
func TestInjectClaudeCodePrompt(t *testing.T) {
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
tests := []struct {
name string
body string
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system: "Custom prompt",
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom prompt",
wantSecondText: claudePrefix + "\n\nCustom prompt",
},
{
name: "string system equals Claude Code prompt",
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom",
wantSecondText: claudePrefix + "\n\nCustom",
},
{
name: "array system with existing Claude Code prompt (should dedupe)",
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Other",
wantSecondText: claudePrefix + "\n\nOther",
},
{
name: "empty array",

View File

@@ -0,0 +1,21 @@
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
in := "You are OpenCode, the best coding agent on the planet."
got := sanitizeSystemText(in)
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
}
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
in := "OpenCode and opencode are mentioned."
got := sanitizeToolDescription(in)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require.Equal(t, in, got)
}

File diff suppressed because it is too large Load Diff

View File

@@ -36,6 +36,11 @@ const (
geminiRetryMaxDelay = 16 * time.Second
)
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
const geminiDummyThoughtSignature = "skip_thought_signature_validator"
type GeminiMessagesCompatService struct {
accountRepo AccountRepository
groupRepo GroupRepository
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
}
geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq)
originalClaudeBody := body
proxyURL := ""
@@ -931,6 +937,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@@ -938,6 +951,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
@@ -969,6 +984,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
body = ensureGeminiFunctionCallThoughtSignatures(body)
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
@@ -1371,6 +1390,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
usage = &ClaudeUsage{}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@@ -1378,6 +1404,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
@@ -2504,9 +2532,13 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
}
prompt, _ := asInt(usageMeta["promptTokenCount"])
cand, _ := asInt(usageMeta["candidatesTokenCount"])
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
return &ClaudeUsage{
InputTokens: prompt,
OutputTokens: cand,
InputTokens: prompt - cached,
OutputTokens: cand,
CacheReadInputTokens: cached,
}
}
@@ -2635,6 +2667,58 @@ func nextGeminiDailyResetUnix() *int64 {
return &ts
}
func ensureGeminiFunctionCallThoughtSignatures(body []byte) []byte {
// Fast path: only run when functionCall is present.
if !bytes.Contains(body, []byte(`"functionCall"`)) {
return body
}
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return body
}
contentsAny, ok := payload["contents"].([]any)
if !ok || len(contentsAny) == 0 {
return body
}
modified := false
for _, c := range contentsAny {
cm, ok := c.(map[string]any)
if !ok {
continue
}
partsAny, ok := cm["parts"].([]any)
if !ok || len(partsAny) == 0 {
continue
}
for _, p := range partsAny {
pm, ok := p.(map[string]any)
if !ok || pm == nil {
continue
}
if fc, ok := pm["functionCall"].(map[string]any); !ok || fc == nil {
continue
}
ts, _ := pm["thoughtSignature"].(string)
if strings.TrimSpace(ts) == "" {
pm["thoughtSignature"] = geminiDummyThoughtSignature
modified = true
}
}
}
if !modified {
return body
}
b, err := json.Marshal(payload)
if err != nil {
return body
}
return b
}
func extractGeminiFinishReason(geminiResp map[string]any) string {
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
@@ -2834,7 +2918,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
toolUseIDToName[id] = name
}
signature, _ := bm["signature"].(string)
signature = strings.TrimSpace(signature)
if signature == "" {
signature = geminiDummyThoughtSignature
}
parts = append(parts, map[string]any{
"thoughtSignature": signature,
"functionCall": map[string]any{
"name": name,
"args": bm["input"],
@@ -3031,3 +3121,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
}
return out
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
var req struct {
GenerationConfig *struct {
ImageConfig *struct {
ImageSize string `json:"imageSize"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
if err := json.Unmarshal(body, &req); err != nil {
return "2K"
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
if size == "1K" || size == "2K" || size == "4K" {
return size
}
}
return "2K"
}

View File

@@ -1,6 +1,8 @@
package service
import (
"encoding/json"
"strings"
"testing"
)
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
})
}
}
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001",
"max_tokens": 10,
"messages": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "hi"},
},
},
map[string]any{
"role": "assistant",
"content": []any{
map[string]any{"type": "text", "text": "ok"},
map[string]any{
"type": "tool_use",
"id": "toolu_123",
"name": "default_api:write_file",
"input": map[string]any{"path": "a.txt", "content": "x"},
// no signature on purpose
},
},
},
},
"tools": []any{
map[string]any{
"name": "default_api:write_file",
"description": "write file",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{"path": map[string]any{"type": "string"}},
},
},
},
}
b, _ := json.Marshal(claudeReq)
out, err := convertClaudeMessagesToGeminiGenerateContent(b)
if err != nil {
t.Fatalf("convert failed: %v", err)
}
s := string(out)
if !strings.Contains(s, "\"functionCall\"") {
t.Fatalf("expected functionCall in output, got: %s", s)
}
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}
func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing.T) {
geminiReq := map[string]any{
"contents": []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{
"functionCall": map[string]any{
"name": "default_api:write_file",
"args": map[string]any{"path": "a.txt"},
},
},
},
},
},
}
b, _ := json.Marshal(geminiReq)
out := ensureGeminiFunctionCallThoughtSignatures(b)
s := string(out)
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}

View File

@@ -221,6 +221,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
return 0, nil
}
func (m *mockGroupRepoForGemini) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock

View File

@@ -0,0 +1,72 @@
package service
import (
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
if len(body) == 0 {
return body
}
// 解析 JSON
var data any
if err := json.Unmarshal(body, &data); err != nil {
// 如果解析失败,返回原始 body可能不是 JSON 或格式不正确)
return body
}
// 递归清理 thoughtSignature
cleaned := cleanThoughtSignaturesRecursive(data)
// 重新序列化
result, err := json.Marshal(cleaned)
if err != nil {
// 如果序列化失败,返回原始 body
return body
}
return result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func cleanThoughtSignaturesRecursive(data any) any {
switch v := data.(type) {
case map[string]any:
// 创建新的 map移除 thoughtSignature
result := make(map[string]any, len(v))
for key, value := range v {
// 跳过 thoughtSignature 字段
if key == "thoughtSignature" {
continue
}
// 递归处理嵌套结构
result[key] = cleanThoughtSignaturesRecursive(value)
}
return result
case []any:
// 递归处理数组中的每个元素
result := make([]any, len(v))
for i, item := range v {
result[i] = cleanThoughtSignaturesRecursive(item)
}
return result
default:
// 基本类型string, number, bool, null直接返回
return v
}
}

View File

@@ -29,6 +29,10 @@ type GroupRepository interface {
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID去重
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
// BindAccountsToGroup 将多个账号绑定到指定分组
BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error
}
// CreateGroupRequest 创建分组请求

View File

@@ -26,13 +26,13 @@ var (
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)",
UserAgent: "claude-cli/2.1.22 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.52.0",
StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux",
StainlessArch: "x64",
StainlessArch: "arm64",
StainlessRuntime: "node",
StainlessRuntimeVersion: "v22.14.0",
StainlessRuntimeVersion: "v24.13.0",
}
// Fingerprint represents account fingerprint data
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
}
// parseUserAgentVersion 解析user-agent版本号
// 例如claude-cli/2.0.62 -> (2, 0, 62)
// 例如claude-cli/2.1.2 -> (2, 1, 2)
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
// 匹配 xxx/x.y.z 格式
matches := userAgentVersionRegex.FindStringSubmatch(ua)

View File

@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
UpdatedAt string `json:"updated_at,omitempty"`
}
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
type NormalizedCodexLimits struct {
Used5hPercent *float64
Reset5hSeconds *int
Window5hMinutes *int
Used7dPercent *float64
Reset7dSeconds *int
Window7dMinutes *int
}
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
// Returns nil if snapshot is nil or has no useful data.
func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits {
if s == nil {
return nil
}
result := &NormalizedCodexLimits{}
primaryMins := 0
secondaryMins := 0
hasPrimaryWindow := false
hasSecondaryWindow := false
if s.PrimaryWindowMinutes != nil {
primaryMins = *s.PrimaryWindowMinutes
hasPrimaryWindow = true
}
if s.SecondaryWindowMinutes != nil {
secondaryMins = *s.SecondaryWindowMinutes
hasSecondaryWindow = true
}
// Determine mapping based on window_minutes
use5hFromPrimary := false
use7dFromPrimary := false
if hasPrimaryWindow && hasSecondaryWindow {
// Both known: smaller window is 5h, larger is 7d
if primaryMins < secondaryMins {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasPrimaryWindow {
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
if primaryMins <= 360 {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasSecondaryWindow {
// Only secondary known: classify by threshold
if secondaryMins <= 360 {
// 5h from secondary, so primary (if any data) is 7d
use7dFromPrimary = true
} else {
// 7d from secondary, so primary (if any data) is 5h
use5hFromPrimary = true
}
} else {
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
use7dFromPrimary = true
}
// Assign values
if use5hFromPrimary {
result.Used5hPercent = s.PrimaryUsedPercent
result.Reset5hSeconds = s.PrimaryResetAfterSeconds
result.Window5hMinutes = s.PrimaryWindowMinutes
result.Used7dPercent = s.SecondaryUsedPercent
result.Reset7dSeconds = s.SecondaryResetAfterSeconds
result.Window7dMinutes = s.SecondaryWindowMinutes
} else if use7dFromPrimary {
result.Used7dPercent = s.PrimaryUsedPercent
result.Reset7dSeconds = s.PrimaryResetAfterSeconds
result.Window7dMinutes = s.PrimaryWindowMinutes
result.Used5hPercent = s.SecondaryUsedPercent
result.Reset5hSeconds = s.SecondaryResetAfterSeconds
result.Window5hMinutes = s.SecondaryWindowMinutes
}
return result
}
// OpenAIUsage represents OpenAI API response usage
type OpenAIUsage struct {
InputTokens int `json:"input_tokens"`
@@ -70,12 +156,15 @@ type OpenAIUsage struct {
// OpenAIForwardResult represents the result of forwarding
type OpenAIForwardResult struct {
RequestID string
Usage OpenAIUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int
RequestID string
Usage OpenAIUsage
Model string
// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
// Stored for usage records display; nil means not provided / not applicable.
ReasoningEffort *string
Stream bool
Duration time.Duration
FirstTokenMs *int
}
// OpenAIGatewayService handles OpenAI API gateway operations
@@ -756,6 +845,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified = true
}
}
// Remove prompt_cache_retention (not supported by upstream OpenAI API)
if _, has := reqBody["prompt_cache_retention"]; has {
delete(reqBody, "prompt_cache_retention")
bodyModified = true
}
}
// Re-serialize body only if modified
@@ -867,18 +962,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
@@ -1174,15 +1272,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
// 仅发送一次错误事件,避免多次写入导致协议混乱
// 注意OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema
// 否则下游 SDK例如 OpenCode会因为类型校验失败而报错。
errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sendErrorEvent := func(reason string) {
if errorEventSent {
if errorEventSent || clientDisconnected {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
payload := map[string]any{
"type": "error",
"sequence_number": 0,
"error": map[string]any{
"type": "upstream_error",
"message": reason,
"code": reason,
},
}
if b, err := json.Marshal(payload); err == nil {
_, _ = fmt.Fprintf(w, "data: %s\n\n", b)
flusher.Flush()
}
}
needModelReplace := originalModel != mappedModel
@@ -1194,6 +1306,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event避免下游 SDK 解析失败。
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
@@ -1217,15 +1340,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
data = correctedData
line = "data: " + correctedData
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
// 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
}
flusher.Flush()
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
@@ -1235,11 +1362,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
}
flusher.Flush()
}
case <-intervalCh:
@@ -1247,6 +1377,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil {
@@ -1256,11 +1390,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
continue
}
flusher.Flush()
}
@@ -1601,6 +1740,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
ReasoningEffort: result.ReasoningEffort,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
@@ -1665,8 +1805,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return nil
}
// extractCodexUsageHeaders extracts Codex usage limits from response headers
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
snapshot := &OpenAICodexUsageSnapshot{}
hasData := false
@@ -1740,6 +1881,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
// Convert snapshot to map for merging into Extra
updates := make(map[string]any)
// Save raw primary/secondary fields for debugging/tracing
if snapshot.PrimaryUsedPercent != nil {
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
}
@@ -1763,109 +1906,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
}
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
// Normalize to canonical 5h/7d fields based on window_minutes
// This fixes the issue where OpenAI's primary/secondary naming is reversed
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
// IMPORTANT: We can only reliably determine window type from window_minutes field
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
var primaryWindowMins, secondaryWindowMins int
var hasPrimaryWindow, hasSecondaryWindow bool
// Only use window_minutes for reliable window size comparison
if snapshot.PrimaryWindowMinutes != nil {
primaryWindowMins = *snapshot.PrimaryWindowMinutes
hasPrimaryWindow = true
}
if snapshot.SecondaryWindowMinutes != nil {
secondaryWindowMins = *snapshot.SecondaryWindowMinutes
hasSecondaryWindow = true
}
// Determine which is 5h and which is 7d
var use5hFromPrimary, use7dFromPrimary bool
var use5hFromSecondary, use7dFromSecondary bool
if hasPrimaryWindow && hasSecondaryWindow {
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
if primaryWindowMins < secondaryWindowMins {
use5hFromPrimary = true
use7dFromSecondary = true
} else {
use5hFromSecondary = true
use7dFromPrimary = true
// Normalize to canonical 5h/7d fields
if normalized := snapshot.Normalize(); normalized != nil {
if normalized.Used5hPercent != nil {
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
}
} else if hasPrimaryWindow {
// Only primary window size known: classify by absolute threshold
if primaryWindowMins <= 360 {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
if normalized.Reset5hSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds
}
} else if hasSecondaryWindow {
// Only secondary window size known: classify by absolute threshold
if secondaryWindowMins <= 360 {
use5hFromSecondary = true
} else {
use7dFromSecondary = true
if normalized.Window5hMinutes != nil {
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes
}
} else {
// No window_minutes available: cannot reliably determine window types
// Fall back to legacy assumption (may be incorrect)
// Assume primary=7d, secondary=5h based on historical observation
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
use5hFromSecondary = true
if normalized.Used7dPercent != nil {
updates["codex_7d_used_percent"] = *normalized.Used7dPercent
}
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
use7dFromPrimary = true
if normalized.Reset7dSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds
}
}
// Write canonical 5h fields
if use5hFromPrimary {
if snapshot.PrimaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use5hFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
}
}
// Write canonical 7d fields
if use7dFromPrimary {
if snapshot.PrimaryUsedPercent != nil {
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use7dFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
if normalized.Window7dMinutes != nil {
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
}
}
@@ -1876,3 +1935,86 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}()
}
func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
if reqBody == nil {
return "", false
}
// Primary: reasoning.effort
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
if effort, ok := reasoning["effort"].(string); ok {
return normalizeOpenAIReasoningEffort(effort), true
}
}
// Fallback: some clients may use a flat field.
if effort, ok := reqBody["reasoning_effort"].(string); ok {
return normalizeOpenAIReasoningEffort(effort), true
}
return "", false
}
func deriveOpenAIReasoningEffortFromModel(model string) string {
if strings.TrimSpace(model) == "" {
return ""
}
modelID := strings.TrimSpace(model)
if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
switch r {
case '-', '_', ' ':
return true
default:
return false
}
})
if len(parts) == 0 {
return ""
}
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
}
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
if value == "" {
return nil
}
return &value
}
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
if value == "" {
return nil
}
return &value
}
func normalizeOpenAIReasoningEffort(raw string) string {
value := strings.ToLower(strings.TrimSpace(raw))
if value == "" {
return ""
}
// Normalize separators for "x-high"/"x_high" variants.
value = strings.NewReplacer("-", "", "_", "", " ", "").Replace(value)
switch value {
case "none", "minimal":
return ""
case "low", "medium", "high":
return value
case "xhigh", "extrahigh":
return "xhigh"
default:
// Only store known effort levels for now to keep UI consistent.
return ""
}
}

View File

@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
skipDefaultLoad bool
}
type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *failingGinWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed")
}
w.writes++
return w.ResponseWriter.Write(p)
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok {
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
t.Fatalf("expected stream timeout error, got %v", err)
}
if !strings.Contains(rec.Body.String(), "stream_timeout") {
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: cancelReadCloser{},
Header: http.Header{},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if result == nil || result.usage == nil {
t.Fatalf("expected usage result")
}
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
t.Fatalf("unexpected usage: %+v", *result.usage)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
}
}
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
if !errors.Is(err, bufio.ErrTooLong) {
t.Fatalf("expected ErrTooLong, got %v", err)
}
if !strings.Contains(rec.Body.String(), "response_too_large") {
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
}
}

View File

@@ -2,9 +2,10 @@ package service
import (
"context"
"fmt"
"net/http"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate PKCE values
state, err := openai.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_STATE_FAILED", "failed to generate state: %v", err)
}
codeVerifier, err := openai.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_VERIFIER_FAILED", "failed to generate code verifier: %v", err)
}
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate session ID
sessionID, err := openai.GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("failed to generate session ID: %w", err)
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_SESSION_FAILED", "failed to generate session ID: %v", err)
}
// Get proxy URL if specified
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
proxyURL = proxy.URL()
}
}
@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Get session
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session not found or expired")
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
}
// Get proxy URL
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
proxyURL = proxy.URL()
}
}
@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Exchange code for token
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
return nil, err
}
// Parse ID token to get user info
@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
// RefreshAccountToken refreshes token for an OpenAI account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() {
return nil, fmt.Errorf("account is not an OpenAI account")
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
}
refreshToken := account.GetOpenAIRefreshToken()
if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
}
var proxyURL string

View File

@@ -67,6 +67,8 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
if acc.Platform != "" {
if _, ok := platform[acc.Platform]; !ok {
platform[acc.Platform] = &PlatformAvailability{
@@ -84,6 +86,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
p.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if p.ScopeRateLimitCount == nil {
p.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
p.ScopeRateLimitCount[scope]++
}
}
}
for _, grp := range acc.Groups {
@@ -108,6 +118,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
g.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if g.ScopeRateLimitCount == nil {
g.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
g.ScopeRateLimitCount[scope]++
}
}
}
displayGroupID := int64(0)
@@ -140,6 +158,9 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
item.RateLimitRemainingSec = &remainingSec
}
}
if len(scopeRateLimits) > 0 {
item.ScopeRateLimits = scopeRateLimits
}
if isOverloaded && acc.OverloadUntil != nil {
item.OverloadUntil = acc.OverloadUntil
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())

View File

@@ -39,22 +39,24 @@ type AccountConcurrencyInfo struct {
// PlatformAvailability aggregates account availability by platform.
type PlatformAvailability struct {
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
}
// GroupAvailability aggregates account availability by group.
type GroupAvailability struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"`
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
}
// AccountAvailability represents current availability for a single account.
@@ -72,10 +74,11 @@ type AccountAvailability struct {
IsOverloaded bool `json:"is_overloaded"`
HasError bool `json:"has_error"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
OverloadUntil *time.Time `json:"overload_until"`
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
ErrorMessage string `json:"error_message"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"`
OverloadUntil *time.Time `json:"overload_until"`
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
ErrorMessage string `json:"error_message"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
}

View File

@@ -83,6 +83,7 @@ type OpsAdvancedSettings struct {
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
}

View File

@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
// 解析重置时间戳
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded
if account.Platform == PlatformOpenAI {
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("openai_account_rate_limited", "account_id", account.ID, "reset_at", *resetAt)
return
}
}
// 2. 尝试从响应头解析重置时间Anthropic
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
// 3. 如果响应头没有尝试从响应体解析OpenAI usage_limit_reached, Gemini
if resetTimestamp == "" {
switch account.Platform {
case PlatformOpenAI:
// 尝试解析 OpenAI 的 usage_limit_reached 错误
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
return
}
case PlatformGemini, PlatformAntigravity:
// 尝试解析 Gemini 格式(用于其他平台)
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
return
}
}
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
}
return
}
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
return strings.Contains(msg, "sonnet")
}
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
snapshot := ParseCodexRateLimitHeaders(headers)
if snapshot == nil {
return nil
}
normalized := snapshot.Normalize()
if normalized == nil {
return nil
}
now := time.Now()
// 判断哪个限制被触发used_percent >= 100
is7dExhausted := normalized.Used7dPercent != nil && *normalized.Used7dPercent >= 100
is5hExhausted := normalized.Used5hPercent != nil && *normalized.Used5hPercent >= 100
// 优先使用被触发限制的重置时间
if is7dExhausted && normalized.Reset7dSeconds != nil {
resetAt := now.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
slog.Info("openai_429_7d_limit_exhausted", "reset_after_seconds", *normalized.Reset7dSeconds, "reset_at", resetAt)
return &resetAt
}
if is5hExhausted && normalized.Reset5hSeconds != nil {
resetAt := now.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
slog.Info("openai_429_5h_limit_exhausted", "reset_after_seconds", *normalized.Reset5hSeconds, "reset_at", resetAt)
return &resetAt
}
// 都未达到100%但收到429使用较长的重置时间
var maxResetSecs int
if normalized.Reset7dSeconds != nil && *normalized.Reset7dSeconds > maxResetSecs {
maxResetSecs = *normalized.Reset7dSeconds
}
if normalized.Reset5hSeconds != nil && *normalized.Reset5hSeconds > maxResetSecs {
maxResetSecs = *normalized.Reset5hSeconds
}
if maxResetSecs > 0 {
resetAt := now.Add(time.Duration(maxResetSecs) * time.Second)
slog.Info("openai_429_using_max_reset", "max_reset_seconds", maxResetSecs, "reset_at", resetAt)
return &resetAt
}
return nil
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
// {
// "error": {
// "message": "The usage limit has been reached",
// "type": "usage_limit_reached",
// "resets_at": 1769404154,
// "resets_in_seconds": 133107
// }
// }
func parseOpenAIRateLimitResetTime(body []byte) *int64 {
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err != nil {
return nil
}
errObj, ok := parsed["error"].(map[string]any)
if !ok {
return nil
}
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
errType, _ := errObj["type"].(string)
if errType != "usage_limit_reached" && errType != "rate_limit_exceeded" {
return nil
}
// 优先使用 resets_atUnix 时间戳)
if resetsAt, ok := errObj["resets_at"].(float64); ok {
ts := int64(resetsAt)
return &ts
}
if resetsAt, ok := errObj["resets_at"].(string); ok {
if ts, err := strconv.ParseInt(resetsAt, 10, 64); err == nil {
return &ts
}
}
// 如果没有 resets_at尝试使用 resets_in_seconds
if resetsInSeconds, ok := errObj["resets_in_seconds"].(float64); ok {
ts := time.Now().Unix() + int64(resetsInSeconds)
return &ts
}
if resetsInSeconds, ok := errObj["resets_in_seconds"].(string); ok {
if sec, err := strconv.ParseInt(resetsInSeconds, 10, 64); err == nil {
ts := time.Now().Unix() + sec
return &ts
}
}
return nil
}
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {

View File

@@ -0,0 +1,364 @@
package service
import (
"net/http"
"testing"
"time"
)
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
svc := &RateLimitService{}
// Simulate headers when 7d limit is exhausted (100% used)
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "384607") // ~4.5 days
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
headers.Set("x-codex-secondary-used-percent", "3")
headers.Set("x-codex-secondary-reset-after-seconds", "17369") // ~4.8 hours
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should be approximately 384607 seconds from now
expectedDuration := 384607 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_5hExhausted(t *testing.T) {
svc := &RateLimitService{}
// Simulate headers when 5h limit is exhausted (100% used)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "50")
headers.Set("x-codex-primary-reset-after-seconds", "500000")
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "3600") // 1 hour
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should be approximately 3600 seconds from now
expectedDuration := 3600 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax(t *testing.T) {
svc := &RateLimitService{}
// Neither limit at 100%, should use the longer reset time
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "80")
headers.Set("x-codex-primary-reset-after-seconds", "100000")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "90")
headers.Set("x-codex-secondary-reset-after-seconds", "5000")
headers.Set("x-codex-secondary-window-minutes", "300")
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should use the max (100000 seconds from 7d window)
expectedDuration := 100000 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_NoCodexHeaders(t *testing.T) {
svc := &RateLimitService{}
// No codex headers at all
headers := http.Header{}
headers.Set("content-type", "application/json")
resetAt := svc.calculateOpenAI429ResetTime(headers)
if resetAt != nil {
t.Errorf("expected nil resetAt when no codex headers, got %v", resetAt)
}
}
func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
svc := &RateLimitService{}
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100") // This is 5h
headers.Set("x-codex-primary-reset-after-seconds", "3600") // 1 hour
headers.Set("x-codex-primary-window-minutes", "300") // 5 hours - smaller!
headers.Set("x-codex-secondary-used-percent", "50")
headers.Set("x-codex-secondary-reset-after-seconds", "500000")
headers.Set("x-codex-secondary-window-minutes", "10080") // 7 days - larger!
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should correctly identify that primary is 5h (smaller window) and use its reset time
expectedDuration := 3600 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestNormalizedCodexLimits(t *testing.T) {
// Test the Normalize() method directly
pUsed := 100.0
pReset := 384607
pWindow := 10080
sUsed := 3.0
sReset := 17369
sWindow := 300
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
PrimaryWindowMinutes: &pWindow,
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
SecondaryWindowMinutes: &sWindow,
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Primary has larger window (10080 > 300), so primary should be 7d
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 384607 {
t.Errorf("expected Reset7dSeconds=384607, got %v", normalized.Reset7dSeconds)
}
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 3.0 {
t.Errorf("expected Used5hPercent=3, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 17369 {
t.Errorf("expected Reset5hSeconds=17369, got %v", normalized.Reset5hSeconds)
}
}
func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
// Test when only primary has data, no window_minutes
pUsed := 80.0
pReset := 50000
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
// No window_minutes, no secondary data
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 80.0 {
t.Errorf("expected Used7dPercent=80, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 50000 {
t.Errorf("expected Reset7dSeconds=50000, got %v", normalized.Reset7dSeconds)
}
// Secondary (5h) should be nil
if normalized.Used5hPercent != nil {
t.Errorf("expected Used5hPercent=nil, got %v", *normalized.Used5hPercent)
}
if normalized.Reset5hSeconds != nil {
t.Errorf("expected Reset5hSeconds=nil, got %v", *normalized.Reset5hSeconds)
}
}
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
// Test when only secondary has data, no window_minutes
sUsed := 60.0
sReset := 3000
snapshot := &OpenAICodexUsageSnapshot{
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
// No window_minutes, no primary data
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
// So secondary goes to 5h
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 60.0 {
t.Errorf("expected Used5hPercent=60, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 3000 {
t.Errorf("expected Reset5hSeconds=3000, got %v", normalized.Reset5hSeconds)
}
// Primary (7d) should be nil
if normalized.Used7dPercent != nil {
t.Errorf("expected Used7dPercent=nil, got %v", *normalized.Used7dPercent)
}
}
func TestNormalizedCodexLimits_BothDataNoWindowMinutes(t *testing.T) {
// Test when both have data but no window_minutes
pUsed := 100.0
pReset := 400000
sUsed := 50.0
sReset := 10000
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
// No window_minutes
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 400000 {
t.Errorf("expected Reset7dSeconds=400000, got %v", normalized.Reset7dSeconds)
}
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 50.0 {
t.Errorf("expected Used5hPercent=50, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 10000 {
t.Errorf("expected Reset5hSeconds=10000, got %v", normalized.Reset5hSeconds)
}
}
func TestHandle429_AnthropicPlatformUnaffected(t *testing.T) {
// Verify that Anthropic platform accounts still use the original logic
// This test ensures we don't break existing Claude account rate limiting
svc := &RateLimitService{}
// Simulate Anthropic 429 headers
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-reset", "1737820800") // A future Unix timestamp
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
// because it only handles OpenAI platform
resetAt := svc.calculateOpenAI429ResetTime(headers)
// Should return nil since there are no x-codex-* headers
if resetAt != nil {
t.Errorf("expected nil for Anthropic headers, got %v", resetAt)
}
}
func TestCalculateOpenAI429ResetTime_UserProvidedScenario(t *testing.T) {
// This is the exact scenario from the user:
// codex_7d_used_percent: 100
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
// codex_5h_used_percent: 3
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
svc := &RateLimitService{}
// Simulate headers matching user's data
// Note: We need to map the canonical 5h/7d back to primary/secondary
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "384607")
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days = 10080 minutes
headers.Set("x-codex-secondary-used-percent", "3")
headers.Set("x-codex-secondary-reset-after-seconds", "17369")
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours = 300 minutes
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt for user scenario")
}
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
expectedDuration := 384607 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
// Verify it's approximately 4.45 days (384607 seconds)
duration := resetAt.Sub(before)
actualDays := duration.Hours() / 24.0
// 384607 / 86400 = ~4.45 days
if actualDays < 4.4 || actualDays > 4.5 {
t.Errorf("expected ~4.45 days, got %.2f days", actualDays)
}
t.Logf("User scenario: reset_at=%v, duration=%.2f days", resetAt, actualDays)
}
func TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset(t *testing.T) {
// Test that we return nil when there's used_percent but no reset_after_seconds
// This should cause the caller to use the default 5-minute fallback
svc := &RateLimitService{}
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
// No reset_after_seconds!
resetAt := svc.calculateOpenAI429ResetTime(headers)
// Should return nil since there's no reset time available
if resetAt != nil {
t.Errorf("expected nil when no reset_after_seconds, got %v", resetAt)
}
}

View File

@@ -49,6 +49,11 @@ type RedeemCodeRepository interface {
List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error)
// ListByUserPaginated returns paginated balance/concurrency history for a specific user.
// codeType filter is optional - pass empty string to return all types.
ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error)
// SumPositiveBalanceByUser returns the total recharged amount (sum of positive balance values) for a user.
SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error)
}
// GenerateCodesRequest 生成兑换码请求
@@ -126,7 +131,8 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
return nil, errors.New("count must be greater than 0")
}
if req.Value <= 0 {
// 邀请码类型不需要数值,其他类型需要
if req.Type != RedeemTypeInvitation && req.Value <= 0 {
return nil, errors.New("value must be greater than 0")
}
@@ -139,6 +145,12 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
codeType = RedeemTypeBalance
}
// 邀请码类型的 value 设为 0
value := req.Value
if codeType == RedeemTypeInvitation {
value = 0
}
codes := make([]RedeemCode, 0, req.Count)
for i := 0; i < req.Count; i++ {
code, err := s.GenerateRandomCode()
@@ -149,7 +161,7 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
codes = append(codes, RedeemCode{
Code: code,
Type: codeType,
Value: req.Value,
Value: value,
Status: StatusUnused,
})
}

View File

@@ -61,6 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
SettingKeyPromoCodeEnabled,
SettingKeyPasswordResetEnabled,
SettingKeyInvitationCodeEnabled,
SettingKeyTotpEnabled,
SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey,
SettingKeySiteName,
@@ -71,6 +74,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyDocURL,
SettingKeyHomeContent,
SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
SettingKeyLinuxDoConnectEnabled,
}
@@ -86,21 +91,30 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
}
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true"
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
LinuxDoOAuthEnabled: linuxDoEnabled,
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: passwordResetEnabled,
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil
}
@@ -125,37 +139,47 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
// Return a struct that matches the frontend's expected format
return &struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo,omitempty"`
SiteSubtitle string `json:"site_subtitle,omitempty"`
APIBaseURL string `json:"api_base_url,omitempty"`
ContactInfo string `json:"contact_info,omitempty"`
DocURL string `json:"doc_url,omitempty"`
HomeContent string `json:"home_content,omitempty"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo,omitempty"`
SiteSubtitle string `json:"site_subtitle,omitempty"`
APIBaseURL string `json:"api_base_url,omitempty"`
ContactInfo string `json:"contact_info,omitempty"`
DocURL string `json:"doc_url,omitempty"`
HomeContent string `json:"home_content,omitempty"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
}, nil
}
@@ -167,6 +191,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost
@@ -203,6 +230,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyDocURL] = settings.DocURL
updates[SettingKeyHomeContent] = settings.HomeContent
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@@ -262,6 +291,44 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
return value != "false"
}
// IsInvitationCodeEnabled 检查是否启用邀请码注册功能
func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyInvitationCodeEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
// Password reset requires email verification to be enabled
if !s.IsEmailVerifyEnabled(ctx) {
return false
}
value, err := s.settingRepo.GetValue(ctx, SettingKeyPasswordResetEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能
func (s *SettingService) IsTotpEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyTotpEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
func (s *SettingService) IsTotpEncryptionKeyConfigured() bool {
return s.cfg.Totp.EncryptionKeyConfigured
}
// GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
@@ -309,15 +376,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults := map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
@@ -340,10 +409,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
@@ -361,6 +434,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
}
// 解析整数类型

Some files were not shown because too many files have changed in this diff Show More