mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-21 07:04:45 +08:00
Merge branch 'main' into test
This commit is contained in:
@@ -47,6 +47,7 @@ type Config struct {
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Ops OpsConfig `mapstructure:"ops"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Totp TotpConfig `mapstructure:"totp"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
@@ -479,6 +480,8 @@ type RedisConfig struct {
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
|
||||
MinIdleConns int `mapstructure:"min_idle_conns"`
|
||||
// EnableTLS: 是否启用 TLS/SSL 连接
|
||||
EnableTLS bool `mapstructure:"enable_tls"`
|
||||
}
|
||||
|
||||
func (r *RedisConfig) Address() string {
|
||||
@@ -531,6 +534,16 @@ type JWTConfig struct {
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
}
|
||||
|
||||
// TotpConfig TOTP 双因素认证配置
|
||||
type TotpConfig struct {
|
||||
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥(32 字节 hex 编码)
|
||||
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
|
||||
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
|
||||
EncryptionKeyConfigured bool `mapstructure:"-"`
|
||||
}
|
||||
|
||||
type TurnstileConfig struct {
|
||||
Required bool `mapstructure:"required"`
|
||||
}
|
||||
@@ -691,6 +704,20 @@ func Load() (*Config, error) {
|
||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
}
|
||||
|
||||
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||
if cfg.Totp.EncryptionKey == "" {
|
||||
key, err := generateJWTSecret(32) // Reuse the same random generation function
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
|
||||
}
|
||||
cfg.Totp.EncryptionKey = key
|
||||
cfg.Totp.EncryptionKeyConfigured = false
|
||||
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||||
} else {
|
||||
cfg.Totp.EncryptionKeyConfigured = true
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
@@ -802,6 +829,7 @@ func setDefaults() {
|
||||
viper.SetDefault("redis.write_timeout_seconds", 3)
|
||||
viper.SetDefault("redis.pool_size", 128)
|
||||
viper.SetDefault("redis.min_idle_conns", 10)
|
||||
viper.SetDefault("redis.enable_tls", false)
|
||||
|
||||
// Ops (vNext)
|
||||
viper.SetDefault("ops.enabled", true)
|
||||
@@ -821,6 +849,9 @@ func setDefaults() {
|
||||
viper.SetDefault("jwt.secret", "")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
|
||||
// TOTP
|
||||
viper.SetDefault("totp.encryption_key", "")
|
||||
|
||||
// Default
|
||||
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
|
||||
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
|
||||
|
||||
226
backend/internal/domain/announcement.go
Normal file
226
backend/internal/domain/announcement.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementStatusDraft = "draft"
|
||||
AnnouncementStatusActive = "active"
|
||||
AnnouncementStatusArchived = "archived"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = "subscription"
|
||||
AnnouncementConditionTypeBalance = "balance"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementOperatorIn = "in"
|
||||
AnnouncementOperatorGT = "gt"
|
||||
AnnouncementOperatorGTE = "gte"
|
||||
AnnouncementOperatorLT = "lt"
|
||||
AnnouncementOperatorLTE = "lte"
|
||||
AnnouncementOperatorEQ = "eq"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAnnouncementNotFound = infraerrors.NotFound("ANNOUNCEMENT_NOT_FOUND", "announcement not found")
|
||||
ErrAnnouncementInvalidTarget = infraerrors.BadRequest("ANNOUNCEMENT_INVALID_TARGET", "invalid announcement targeting rules")
|
||||
)
|
||||
|
||||
type AnnouncementTargeting struct {
|
||||
// AnyOf 表示 OR:任意一个条件组满足即可展示。
|
||||
AnyOf []AnnouncementConditionGroup `json:"any_of,omitempty"`
|
||||
}
|
||||
|
||||
type AnnouncementConditionGroup struct {
|
||||
// AllOf 表示 AND:组内所有条件都满足才算命中该组。
|
||||
AllOf []AnnouncementCondition `json:"all_of,omitempty"`
|
||||
}
|
||||
|
||||
type AnnouncementCondition struct {
|
||||
// Type: subscription | balance
|
||||
Type string `json:"type"`
|
||||
|
||||
// Operator:
|
||||
// - subscription: in
|
||||
// - balance: gt/gte/lt/lte/eq
|
||||
Operator string `json:"operator"`
|
||||
|
||||
// subscription 条件:匹配的订阅套餐(group_id)
|
||||
GroupIDs []int64 `json:"group_ids,omitempty"`
|
||||
|
||||
// balance 条件:比较阈值
|
||||
Value float64 `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
func (t AnnouncementTargeting) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
|
||||
// 空规则:展示给所有用户
|
||||
if len(t.AnyOf) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, group := range t.AnyOf {
|
||||
if len(group.AllOf) == 0 {
|
||||
// 空条件组不命中(避免 OR 中出现无条件 “全命中”)
|
||||
continue
|
||||
}
|
||||
allMatched := true
|
||||
for _, cond := range group.AllOf {
|
||||
if !cond.Matches(balance, activeSubscriptionGroupIDs) {
|
||||
allMatched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allMatched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c AnnouncementCondition) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
|
||||
switch c.Type {
|
||||
case AnnouncementConditionTypeSubscription:
|
||||
if c.Operator != AnnouncementOperatorIn {
|
||||
return false
|
||||
}
|
||||
if len(c.GroupIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
if len(activeSubscriptionGroupIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, gid := range c.GroupIDs {
|
||||
if _, ok := activeSubscriptionGroupIDs[gid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
case AnnouncementConditionTypeBalance:
|
||||
switch c.Operator {
|
||||
case AnnouncementOperatorGT:
|
||||
return balance > c.Value
|
||||
case AnnouncementOperatorGTE:
|
||||
return balance >= c.Value
|
||||
case AnnouncementOperatorLT:
|
||||
return balance < c.Value
|
||||
case AnnouncementOperatorLTE:
|
||||
return balance <= c.Value
|
||||
case AnnouncementOperatorEQ:
|
||||
return balance == c.Value
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (t AnnouncementTargeting) NormalizeAndValidate() (AnnouncementTargeting, error) {
|
||||
normalized := AnnouncementTargeting{AnyOf: make([]AnnouncementConditionGroup, 0, len(t.AnyOf))}
|
||||
|
||||
// 允许空 targeting(展示给所有用户)
|
||||
if len(t.AnyOf) == 0 {
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
if len(t.AnyOf) > 50 {
|
||||
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
|
||||
}
|
||||
|
||||
for _, g := range t.AnyOf {
|
||||
if len(g.AllOf) == 0 {
|
||||
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
|
||||
}
|
||||
if len(g.AllOf) > 50 {
|
||||
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
|
||||
}
|
||||
|
||||
group := AnnouncementConditionGroup{AllOf: make([]AnnouncementCondition, 0, len(g.AllOf))}
|
||||
for _, c := range g.AllOf {
|
||||
cond := AnnouncementCondition{
|
||||
Type: strings.TrimSpace(c.Type),
|
||||
Operator: strings.TrimSpace(c.Operator),
|
||||
Value: c.Value,
|
||||
}
|
||||
for _, gid := range c.GroupIDs {
|
||||
if gid <= 0 {
|
||||
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
|
||||
}
|
||||
cond.GroupIDs = append(cond.GroupIDs, gid)
|
||||
}
|
||||
|
||||
if err := cond.validate(); err != nil {
|
||||
return AnnouncementTargeting{}, err
|
||||
}
|
||||
group.AllOf = append(group.AllOf, cond)
|
||||
}
|
||||
|
||||
normalized.AnyOf = append(normalized.AnyOf, group)
|
||||
}
|
||||
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func (c AnnouncementCondition) validate() error {
|
||||
switch c.Type {
|
||||
case AnnouncementConditionTypeSubscription:
|
||||
if c.Operator != AnnouncementOperatorIn {
|
||||
return ErrAnnouncementInvalidTarget
|
||||
}
|
||||
if len(c.GroupIDs) == 0 {
|
||||
return ErrAnnouncementInvalidTarget
|
||||
}
|
||||
return nil
|
||||
|
||||
case AnnouncementConditionTypeBalance:
|
||||
switch c.Operator {
|
||||
case AnnouncementOperatorGT, AnnouncementOperatorGTE, AnnouncementOperatorLT, AnnouncementOperatorLTE, AnnouncementOperatorEQ:
|
||||
return nil
|
||||
default:
|
||||
return ErrAnnouncementInvalidTarget
|
||||
}
|
||||
|
||||
default:
|
||||
return ErrAnnouncementInvalidTarget
|
||||
}
|
||||
}
|
||||
|
||||
type Announcement struct {
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (a *Announcement) IsActiveAt(now time.Time) bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
if a.Status != AnnouncementStatusActive {
|
||||
return false
|
||||
}
|
||||
if a.StartsAt != nil && now.Before(*a.StartsAt) {
|
||||
return false
|
||||
}
|
||||
if a.EndsAt != nil && !now.Before(*a.EndsAt) {
|
||||
// ends_at 语义:到点即下线
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
66
backend/internal/domain/constants.go
Normal file
66
backend/internal/domain/constants.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package domain
|
||||
|
||||
// Status constants
|
||||
const (
|
||||
StatusActive = "active"
|
||||
StatusDisabled = "disabled"
|
||||
StatusError = "error"
|
||||
StatusUnused = "unused"
|
||||
StatusUsed = "used"
|
||||
StatusExpired = "expired"
|
||||
)
|
||||
|
||||
// Role constants
|
||||
const (
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
)
|
||||
|
||||
// Platform constants
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformSora = "sora"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
const (
|
||||
RedeemTypeBalance = "balance"
|
||||
RedeemTypeConcurrency = "concurrency"
|
||||
RedeemTypeSubscription = "subscription"
|
||||
RedeemTypeInvitation = "invitation"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
const (
|
||||
PromoCodeStatusActive = "active"
|
||||
PromoCodeStatusDisabled = "disabled"
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
|
||||
)
|
||||
|
||||
// Group subscription type constants
|
||||
const (
|
||||
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
|
||||
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
|
||||
)
|
||||
|
||||
// Subscription status constants
|
||||
const (
|
||||
SubscriptionStatusActive = "active"
|
||||
SubscriptionStatusExpired = "expired"
|
||||
SubscriptionStatusSuspended = "suspended"
|
||||
)
|
||||
@@ -547,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||
newCredentials["project_id"] = oldProjectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
@@ -557,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
}
|
||||
// 标记账户为 error
|
||||
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
|
||||
response.InternalError(c, "Failed to set account error: "+setErr.Error())
|
||||
return
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed but project_id is missing, account marked as error",
|
||||
"warning": "missing_project_id",
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -290,5 +290,9 @@ func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*ser
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]service.RedeemCode, int64, float64, error) {
|
||||
return s.redeems, int64(len(s.redeems)), 100.0, nil
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
246
backend/internal/handler/admin/announcement_handler.go
Normal file
246
backend/internal/handler/admin/announcement_handler.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AnnouncementHandler handles admin announcement management
|
||||
type AnnouncementHandler struct {
|
||||
announcementService *service.AnnouncementService
|
||||
}
|
||||
|
||||
// NewAnnouncementHandler creates a new admin announcement handler
|
||||
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
|
||||
return &AnnouncementHandler{
|
||||
announcementService: announcementService,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateAnnouncementRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
}
|
||||
|
||||
type UpdateAnnouncementRequest struct {
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
}
|
||||
|
||||
// List handles listing announcements with filters
|
||||
// GET /api/v1/admin/announcements
|
||||
func (h *AnnouncementHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := strings.TrimSpace(c.Query("status"))
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 200 {
|
||||
search = search[:200]
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
items, paginationResult, err := h.announcementService.List(
|
||||
c.Request.Context(),
|
||||
params,
|
||||
service.AnnouncementListFilters{Status: status, Search: search},
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.Announcement, 0, len(items))
|
||||
for i := range items {
|
||||
out = append(out, *dto.AnnouncementFromService(&items[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting an announcement by ID
|
||||
// GET /api/v1/admin/announcements/:id
|
||||
func (h *AnnouncementHandler) GetByID(c *gin.Context) {
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
item, err := h.announcementService.GetByID(c.Request.Context(), announcementID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AnnouncementFromService(item))
|
||||
}
|
||||
|
||||
// Create handles creating a new announcement
|
||||
// POST /api/v1/admin/announcements
|
||||
func (h *AnnouncementHandler) Create(c *gin.Context) {
|
||||
var req CreateAnnouncementRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.CreateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil && *req.StartsAt > 0 {
|
||||
t := time.Unix(*req.StartsAt, 0)
|
||||
input.StartsAt = &t
|
||||
}
|
||||
if req.EndsAt != nil && *req.EndsAt > 0 {
|
||||
t := time.Unix(*req.EndsAt, 0)
|
||||
input.EndsAt = &t
|
||||
}
|
||||
|
||||
created, err := h.announcementService.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AnnouncementFromService(created))
|
||||
}
|
||||
|
||||
// Update handles updating an announcement
|
||||
// PUT /api/v1/admin/announcements/:id
|
||||
func (h *AnnouncementHandler) Update(c *gin.Context) {
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAnnouncementRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.UpdateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil {
|
||||
if *req.StartsAt == 0 {
|
||||
var cleared *time.Time = nil
|
||||
input.StartsAt = &cleared
|
||||
} else {
|
||||
t := time.Unix(*req.StartsAt, 0)
|
||||
ptr := &t
|
||||
input.StartsAt = &ptr
|
||||
}
|
||||
}
|
||||
|
||||
if req.EndsAt != nil {
|
||||
if *req.EndsAt == 0 {
|
||||
var cleared *time.Time = nil
|
||||
input.EndsAt = &cleared
|
||||
} else {
|
||||
t := time.Unix(*req.EndsAt, 0)
|
||||
ptr := &t
|
||||
input.EndsAt = &ptr
|
||||
}
|
||||
}
|
||||
|
||||
updated, err := h.announcementService.Update(c.Request.Context(), announcementID, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AnnouncementFromService(updated))
|
||||
}
|
||||
|
||||
// Delete handles deleting an announcement
|
||||
// DELETE /api/v1/admin/announcements/:id
|
||||
func (h *AnnouncementHandler) Delete(c *gin.Context) {
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.announcementService.Delete(c.Request.Context(), announcementID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Announcement deleted successfully"})
|
||||
}
|
||||
|
||||
// ListReadStatus handles listing users read status for an announcement
|
||||
// GET /api/v1/admin/announcements/:id/read-status
|
||||
func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 200 {
|
||||
search = search[:200]
|
||||
}
|
||||
|
||||
items, paginationResult, err := h.announcementService.ListUserReadStatus(
|
||||
c.Request.Context(),
|
||||
announcementID,
|
||||
params,
|
||||
search,
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, items, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
@@ -47,6 +47,8 @@ type CreateGroupRequest struct {
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
@@ -74,6 +76,8 @@ type UpdateGroupRequest struct {
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
|
||||
// List handles listing all groups with pagination
|
||||
@@ -183,6 +187,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -229,6 +234,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -29,7 +29,7 @@ func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
|
||||
// GenerateRedeemCodesRequest represents generate redeem codes request
|
||||
type GenerateRedeemCodesRequest struct {
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value" binding:"min=0"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
|
||||
|
||||
@@ -48,6 +48,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
@@ -70,6 +74,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -89,9 +95,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@@ -114,14 +123,16 @@ type UpdateSettingsRequest struct {
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -198,6 +209,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// TOTP 双因素认证参数验证
|
||||
// 只有手动配置了加密密钥才允许启用 TOTP 功能
|
||||
if req.TotpEnabled && !previousSettings.TotpEnabled {
|
||||
// 尝试启用 TOTP,检查加密密钥是否已手动配置
|
||||
if !h.settingService.IsTotpEncryptionKeyConfigured() {
|
||||
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||
@@ -227,6 +248,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// “购买订阅”页面配置验证
|
||||
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
|
||||
if req.PurchaseSubscriptionEnabled != nil {
|
||||
purchaseEnabled = *req.PurchaseSubscriptionEnabled
|
||||
}
|
||||
purchaseURL := previousSettings.PurchaseSubscriptionURL
|
||||
if req.PurchaseSubscriptionURL != nil {
|
||||
purchaseURL = strings.TrimSpace(*req.PurchaseSubscriptionURL)
|
||||
}
|
||||
|
||||
// - 启用时要求 URL 合法且非空
|
||||
// - 禁用时允许为空;若提供了 URL 也做基本校验,避免误配置
|
||||
if purchaseEnabled {
|
||||
if purchaseURL == "" {
|
||||
response.BadRequest(c, "Purchase Subscription URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
|
||||
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
} else if purchaseURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
|
||||
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Ops metrics collector interval validation (seconds).
|
||||
if req.OpsMetricsIntervalSeconds != nil {
|
||||
v := *req.OpsMetricsIntervalSeconds
|
||||
@@ -240,40 +289,45 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -318,6 +372,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
@@ -340,6 +398,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
DocURL: updatedSettings.DocURL,
|
||||
HomeContent: updatedSettings.HomeContent,
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -384,6 +444,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||
changed = append(changed, "email_verify_enabled")
|
||||
}
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
if before.SMTPHost != after.SMTPHost {
|
||||
changed = append(changed, "smtp_host")
|
||||
}
|
||||
|
||||
@@ -77,7 +77,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
}
|
||||
status := c.Query("status")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
// Parse sorting parameters
|
||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -277,3 +277,44 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// GetBalanceHistory handles getting user's balance/concurrency change history
|
||||
// GET /api/v1/admin/users/:id/balance-history
|
||||
// Query params:
|
||||
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
codeType := c.Query("type")
|
||||
|
||||
codes, total, totalRecharged, err := h.adminService.GetUserBalanceHistory(c.Request.Context(), userID, page, pageSize, codeType)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to admin DTO (includes notes field for admin visibility)
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
|
||||
// Custom response with total_recharged alongside pagination
|
||||
pages := int((total + int64(pageSize) - 1) / int64(pageSize))
|
||||
if pages < 1 {
|
||||
pages = 1
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"items": out,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"pages": pages,
|
||||
"total_recharged": totalRecharged,
|
||||
})
|
||||
}
|
||||
|
||||
81
backend/internal/handler/announcement_handler.go
Normal file
81
backend/internal/handler/announcement_handler.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AnnouncementHandler handles user announcement operations
|
||||
type AnnouncementHandler struct {
|
||||
announcementService *service.AnnouncementService
|
||||
}
|
||||
|
||||
// NewAnnouncementHandler creates a new user announcement handler
|
||||
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
|
||||
return &AnnouncementHandler{
|
||||
announcementService: announcementService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing announcements visible to current user
|
||||
// GET /api/v1/announcements
|
||||
func (h *AnnouncementHandler) List(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
unreadOnly := parseBoolQuery(c.Query("unread_only"))
|
||||
|
||||
items, err := h.announcementService.ListForUser(c.Request.Context(), subject.UserID, unreadOnly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserAnnouncement, 0, len(items))
|
||||
for i := range items {
|
||||
out = append(out, *dto.UserAnnouncementFromService(&items[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// MarkRead marks an announcement as read for current user
|
||||
// POST /api/v1/announcements/:id/read
|
||||
func (h *AnnouncementHandler) MarkRead(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.announcementService.MarkRead(c.Request.Context(), subject.UserID, announcementID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
func parseBoolQuery(v string) bool {
|
||||
switch strings.TrimSpace(strings.ToLower(v)) {
|
||||
case "1", "true", "yes", "y", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
@@ -13,21 +15,25 @@ import (
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
redeemService *service.RedeemService
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
redeemService: redeemService,
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +43,8 @@ type RegisterRequest struct {
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
PromoCode string `json:"promo_code"` // 注册优惠码
|
||||
PromoCode string `json:"promo_code"` // 注册优惠码
|
||||
InvitationCode string `json:"invitation_code"` // 邀请码
|
||||
}
|
||||
|
||||
// SendVerifyCodeRequest 发送验证码请求
|
||||
@@ -83,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -144,6 +151,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if TOTP 2FA is enabled for this user
|
||||
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
|
||||
// Create a temporary login session for 2FA
|
||||
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpLoginResponse{
|
||||
Requires2FA: true,
|
||||
TempToken: tempToken,
|
||||
UserEmailMasked: service.MaskEmail(user.Email),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// TotpLoginResponse represents the response when 2FA is required
|
||||
type TotpLoginResponse struct {
|
||||
Requires2FA bool `json:"requires_2fa"`
|
||||
TempToken string `json:"temp_token,omitempty"`
|
||||
UserEmailMasked string `json:"user_email_masked,omitempty"`
|
||||
}
|
||||
|
||||
// Login2FARequest represents the 2FA login request
|
||||
type Login2FARequest struct {
|
||||
TempToken string `json:"temp_token" binding:"required"`
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
}
|
||||
|
||||
// Login2FA completes the login with 2FA verification
|
||||
// POST /api/v1/auth/login/2fa
|
||||
func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
var req Login2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_request",
|
||||
"temp_token_len", len(req.TempToken),
|
||||
"totp_code_len", len(req.TotpCode))
|
||||
|
||||
// Get the login session
|
||||
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
|
||||
if err != nil || session == nil {
|
||||
tokenPrefix := ""
|
||||
if len(req.TempToken) >= 8 {
|
||||
tokenPrefix = req.TempToken[:8]
|
||||
}
|
||||
slog.Debug("login_2fa_session_invalid",
|
||||
"temp_token_prefix", tokenPrefix,
|
||||
"error", err)
|
||||
response.BadRequest(c, "Invalid or expired 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_session_found",
|
||||
"user_id", session.UserID,
|
||||
"email", session.Email)
|
||||
|
||||
// Verify the TOTP code
|
||||
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
|
||||
slog.Debug("login_2fa_verify_failed",
|
||||
"user_id", session.UserID,
|
||||
"error", err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate the JWT token
|
||||
token, err := h.authService.GenerateToken(user)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
@@ -247,3 +348,146 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateInvitationCodeRequest 验证邀请码请求
|
||||
type ValidateInvitationCodeRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// ValidateInvitationCodeResponse 验证邀请码响应
|
||||
type ValidateInvitationCodeResponse struct {
|
||||
Valid bool `json:"valid"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
}
|
||||
|
||||
// ValidateInvitationCode 验证邀请码(公开接口,注册前调用)
|
||||
// POST /api/v1/auth/validate-invitation-code
|
||||
func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) {
|
||||
// 检查邀请码功能是否启用
|
||||
if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) {
|
||||
response.Success(c, ValidateInvitationCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "INVITATION_CODE_DISABLED",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req ValidateInvitationCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证邀请码
|
||||
redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code)
|
||||
if err != nil {
|
||||
response.Success(c, ValidateInvitationCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "INVITATION_CODE_NOT_FOUND",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查类型和状态
|
||||
if redeemCode.Type != service.RedeemTypeInvitation {
|
||||
response.Success(c, ValidateInvitationCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "INVITATION_CODE_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if redeemCode.Status != service.StatusUnused {
|
||||
response.Success(c, ValidateInvitationCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "INVITATION_CODE_USED",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ValidateInvitationCodeResponse{
|
||||
Valid: true,
|
||||
})
|
||||
}
|
||||
|
||||
// ForgotPasswordRequest 忘记密码请求
|
||||
type ForgotPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// ForgotPasswordResponse 忘记密码响应
|
||||
type ForgotPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ForgotPassword 请求密码重置
|
||||
// POST /api/v1/auth/forgot-password
|
||||
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
var req ForgotPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build frontend base URL from request
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
// Check X-Forwarded-Proto header (common in reverse proxy setups)
|
||||
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
frontendBaseURL := scheme + "://" + c.Request.Host
|
||||
|
||||
// Request password reset (async)
|
||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ForgotPasswordResponse{
|
||||
Message: "If your email is registered, you will receive a password reset link shortly.",
|
||||
})
|
||||
}
|
||||
|
||||
// ResetPasswordRequest 重置密码请求
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Token string `json:"token" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// ResetPasswordResponse 重置密码响应
|
||||
type ResetPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// POST /api/v1/auth/reset-password
|
||||
func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
var req ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Reset password
|
||||
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ResetPasswordResponse{
|
||||
Message: "Your password has been reset successfully. You can now log in with your new password.",
|
||||
})
|
||||
}
|
||||
|
||||
74
backend/internal/handler/dto/announcement.go
Normal file
74
backend/internal/handler/dto/announcement.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type Announcement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
|
||||
StartsAt *time.Time `json:"starts_at,omitempty"`
|
||||
EndsAt *time.Time `json:"ends_at,omitempty"`
|
||||
|
||||
CreatedBy *int64 `json:"created_by,omitempty"`
|
||||
UpdatedBy *int64 `json:"updated_by,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
|
||||
StartsAt *time.Time `json:"starts_at,omitempty"`
|
||||
EndsAt *time.Time `json:"ends_at,omitempty"`
|
||||
|
||||
ReadAt *time.Time `json:"read_at,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func AnnouncementFromService(a *service.Announcement) *Announcement {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &Announcement{
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserAnnouncement{
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
}
|
||||
}
|
||||
@@ -208,6 +208,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 {
|
||||
out.ScopeRateLimits = make(map[string]ScopeRateLimitInfo, len(scopeLimits))
|
||||
now := time.Now()
|
||||
for scope, remainingSec := range scopeLimits {
|
||||
out.ScopeRateLimits[scope] = ScopeRateLimitInfo{
|
||||
ResetAt: now.Add(time.Duration(remainingSec) * time.Second),
|
||||
RemainingSec: remainingSec,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -325,7 +336,7 @@ func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
|
||||
}
|
||||
|
||||
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
|
||||
return RedeemCode{
|
||||
out := RedeemCode{
|
||||
ID: rc.ID,
|
||||
Code: rc.Code,
|
||||
Type: rc.Type,
|
||||
@@ -339,6 +350,14 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
|
||||
User: UserFromServiceShallow(rc.User),
|
||||
Group: GroupFromServiceShallow(rc.Group),
|
||||
}
|
||||
|
||||
// For admin_balance/admin_concurrency types, include notes so users can see
|
||||
// why they were charged or credited by admin
|
||||
if (rc.Type == "admin_balance" || rc.Type == "admin_concurrency") && rc.Notes != "" {
|
||||
out.Notes = &rc.Notes
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// AccountSummaryFromService returns a minimal AccountSummary for usage log display.
|
||||
@@ -362,6 +381,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
|
||||
@@ -2,9 +2,13 @@ package dto
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@@ -23,14 +27,16 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -54,21 +60,26 @@ type SystemSettings struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
|
||||
@@ -2,6 +2,11 @@ package dto
|
||||
|
||||
import "time"
|
||||
|
||||
type ScopeRateLimitInfo struct {
|
||||
ResetAt time.Time `json:"reset_at"`
|
||||
RemainingSec int64 `json:"remaining_sec"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
@@ -114,6 +119,9 @@ type Account struct {
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
|
||||
// Antigravity scope 级限流状态(从 extra 提取)
|
||||
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
|
||||
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
||||
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
||||
|
||||
@@ -204,6 +212,10 @@ type RedeemCode struct {
|
||||
GroupID *int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
|
||||
// Notes is only populated for admin_balance/admin_concurrency types
|
||||
// so users can see why they were charged or credited
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
@@ -224,6 +236,9 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
|
||||
@@ -30,6 +30,7 @@ type GatewayHandler struct {
|
||||
antigravityGatewayService *service.AntigravityGatewayService
|
||||
userService *service.UserService
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageService *service.UsageService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
@@ -44,6 +45,7 @@ func NewGatewayHandler(
|
||||
userService *service.UserService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageService *service.UsageService,
|
||||
cfg *config.Config,
|
||||
) *GatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@@ -64,6 +66,7 @@ func NewGatewayHandler(
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
userService: userService,
|
||||
billingCacheService: billingCacheService,
|
||||
usageService: usageService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
@@ -537,7 +540,7 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// Usage handles getting account balance for CC Switch integration
|
||||
// Usage handles getting account balance and usage statistics for CC Switch integration
|
||||
// GET /v1/usage
|
||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
@@ -552,7 +555,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 订阅模式:返回订阅限额信息
|
||||
// Best-effort: 获取用量统计,失败不影响基础响应
|
||||
var usageData gin.H
|
||||
if h.usageService != nil {
|
||||
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
|
||||
if err == nil && dashStats != nil {
|
||||
usageData = gin.H{
|
||||
"today": gin.H{
|
||||
"requests": dashStats.TodayRequests,
|
||||
"input_tokens": dashStats.TodayInputTokens,
|
||||
"output_tokens": dashStats.TodayOutputTokens,
|
||||
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
|
||||
"cache_read_tokens": dashStats.TodayCacheReadTokens,
|
||||
"total_tokens": dashStats.TodayTokens,
|
||||
"cost": dashStats.TodayCost,
|
||||
"actual_cost": dashStats.TodayActualCost,
|
||||
},
|
||||
"total": gin.H{
|
||||
"requests": dashStats.TotalRequests,
|
||||
"input_tokens": dashStats.TotalInputTokens,
|
||||
"output_tokens": dashStats.TotalOutputTokens,
|
||||
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
|
||||
"cache_read_tokens": dashStats.TotalCacheReadTokens,
|
||||
"total_tokens": dashStats.TotalTokens,
|
||||
"cost": dashStats.TotalCost,
|
||||
"actual_cost": dashStats.TotalActualCost,
|
||||
},
|
||||
"average_duration_ms": dashStats.AverageDurationMs,
|
||||
"rpm": dashStats.Rpm,
|
||||
"tpm": dashStats.Tpm,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 订阅模式:返回订阅限额信息 + 用量统计
|
||||
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
|
||||
subscription, ok := middleware2.GetSubscriptionFromContext(c)
|
||||
if !ok {
|
||||
@@ -561,28 +597,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
}
|
||||
|
||||
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
resp := gin.H{
|
||||
"isValid": true,
|
||||
"planName": apiKey.Group.Name,
|
||||
"remaining": remaining,
|
||||
"unit": "USD",
|
||||
})
|
||||
"subscription": gin.H{
|
||||
"daily_usage_usd": subscription.DailyUsageUSD,
|
||||
"weekly_usage_usd": subscription.WeeklyUsageUSD,
|
||||
"monthly_usage_usd": subscription.MonthlyUsageUSD,
|
||||
"daily_limit_usd": apiKey.Group.DailyLimitUSD,
|
||||
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
|
||||
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
|
||||
"expires_at": subscription.ExpiresAt,
|
||||
},
|
||||
}
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
// 余额模式:返回钱包余额
|
||||
// 余额模式:返回钱包余额 + 用量统计
|
||||
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
resp := gin.H{
|
||||
"isValid": true,
|
||||
"planName": "钱包余额",
|
||||
"remaining": latestUser.Balance,
|
||||
"unit": "USD",
|
||||
})
|
||||
"balance": latestUser.Balance,
|
||||
}
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// calculateSubscriptionRemaining 计算订阅剩余可用额度
|
||||
@@ -738,6 +792,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
|
||||
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractGeminiCLISessionHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
privilegedUserID string
|
||||
wantEmpty bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "with privileged-user-id and tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: false,
|
||||
wantHash: func() string {
|
||||
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "without privileged-user-id but with tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "",
|
||||
wantEmpty: false,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "without tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 创建测试上下文
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/test", nil)
|
||||
if tt.privilegedUserID != "" {
|
||||
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
|
||||
}
|
||||
|
||||
// 调用函数
|
||||
result := extractGeminiCLISessionHash(c, []byte(tt.body))
|
||||
|
||||
// 验证结果
|
||||
if tt.wantEmpty {
|
||||
require.Empty(t, result, "expected empty session hash")
|
||||
} else {
|
||||
require.NotEmpty(t, result, "expected non-empty session hash")
|
||||
require.Equal(t, tt.wantHash, result, "session hash mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiCLITmpDirRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantMatch bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "valid tmp dir path",
|
||||
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "valid tmp dir path in text",
|
||||
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "invalid hash length",
|
||||
input: "/Users/ianshaw/.gemini/tmp/abc123",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "no tmp dir",
|
||||
input: "Hello world",
|
||||
wantMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
|
||||
if tt.wantMatch {
|
||||
require.NotNil(t, match, "expected regex to match")
|
||||
require.Len(t, match, 2, "expected 2 capture groups")
|
||||
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
|
||||
} else {
|
||||
require.Nil(t, match, "expected regex not to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,15 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +23,17 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
|
||||
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
||||
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||||
|
||||
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
|
||||
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
|
||||
return true
|
||||
}
|
||||
return geminiCLITmpDirRegex.Match(body)
|
||||
}
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
// 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
|
||||
sessionHash := extractGeminiCLISessionHash(c, body)
|
||||
if sessionHash == "" {
|
||||
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
}
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
|
||||
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
}
|
||||
isCLI := isGeminiCLIRequest(c, body)
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionBoundAccountID == 0 {
|
||||
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
|
||||
sessionBoundAccountID = account.ID
|
||||
}
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
@@ -319,18 +366,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 6) record usage async
|
||||
// 6) record usage async (Gemini 使用长上下文双倍计费)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
@@ -433,3 +483,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||
//
|
||||
// 会话标识生成策略:
|
||||
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
|
||||
// 2. 从 header 中提取 privileged-user-id(UUID)
|
||||
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
|
||||
//
|
||||
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
|
||||
//
|
||||
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
|
||||
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
|
||||
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
|
||||
// 1. 从请求体中提取 tmp 目录哈希
|
||||
match := geminiCLITmpDirRegex.FindSubmatch(body)
|
||||
if len(match) < 2 {
|
||||
return "" // 没有找到 tmp 目录,不使用粘性会话
|
||||
}
|
||||
tmpDirHash := string(match[1])
|
||||
|
||||
// 2. 提取 privileged-user-id
|
||||
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
|
||||
|
||||
// 3. 组合生成最终的 session hash
|
||||
if privilegedUserID != "" {
|
||||
// 组合两个标识符:privileged-user-id + tmp 目录哈希
|
||||
combined := privilegedUserID + ":" + tmpDirHash
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
|
||||
return tmpDirHash
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ type AdminHandlers struct {
|
||||
User *admin.UserHandler
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
@@ -33,11 +34,13 @@ type Handlers struct {
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Announcement *AnnouncementHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
SoraGateway *SoraGatewayHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
|
||||
@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
||||
return true
|
||||
}
|
||||
if phase == "billing" || phase == "concurrency" {
|
||||
@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||
if settings.IgnoreInvalidApiKeyErrors {
|
||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -32,20 +32,25 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
181
backend/internal/handler/totp_handler.go
Normal file
181
backend/internal/handler/totp_handler.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TotpHandler handles TOTP-related requests
|
||||
type TotpHandler struct {
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewTotpHandler creates a new TotpHandler
|
||||
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
|
||||
return &TotpHandler{
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
// TotpStatusResponse represents the TOTP status response
|
||||
type TotpStatusResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
|
||||
FeatureEnabled bool `json:"feature_enabled"`
|
||||
}
|
||||
|
||||
// GetStatus returns the TOTP status for the current user
|
||||
// GET /api/v1/user/totp/status
|
||||
func (h *TotpHandler) GetStatus(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := TotpStatusResponse{
|
||||
Enabled: status.Enabled,
|
||||
FeatureEnabled: status.FeatureEnabled,
|
||||
}
|
||||
|
||||
if status.EnabledAt != nil {
|
||||
ts := status.EnabledAt.Unix()
|
||||
resp.EnabledAt = &ts
|
||||
}
|
||||
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// TotpSetupRequest represents the request to initiate TOTP setup
|
||||
type TotpSetupRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// TotpSetupResponse represents the TOTP setup response
|
||||
type TotpSetupResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeURL string `json:"qr_code_url"`
|
||||
SetupToken string `json:"setup_token"`
|
||||
Countdown int `json:"countdown"`
|
||||
}
|
||||
|
||||
// InitiateSetup starts the TOTP setup process
|
||||
// POST /api/v1/user/totp/setup
|
||||
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpSetupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body (optional params)
|
||||
req = TotpSetupRequest{}
|
||||
}
|
||||
|
||||
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpSetupResponse{
|
||||
Secret: result.Secret,
|
||||
QRCodeURL: result.QRCodeURL,
|
||||
SetupToken: result.SetupToken,
|
||||
Countdown: result.Countdown,
|
||||
})
|
||||
}
|
||||
|
||||
// TotpEnableRequest represents the request to enable TOTP
|
||||
type TotpEnableRequest struct {
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
SetupToken string `json:"setup_token" binding:"required"`
|
||||
}
|
||||
|
||||
// Enable completes the TOTP setup
|
||||
// POST /api/v1/user/totp/enable
|
||||
func (h *TotpHandler) Enable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpEnableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// TotpDisableRequest represents the request to disable TOTP
|
||||
type TotpDisableRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// Disable disables TOTP for the current user
|
||||
// POST /api/v1/user/totp/disable
|
||||
func (h *TotpHandler) Disable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpDisableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GetVerificationMethod returns the verification method for TOTP operations
|
||||
// GET /api/v1/user/totp/verification-method
|
||||
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
|
||||
method := h.totpService.GetVerificationMethod(c.Request.Context())
|
||||
response.Success(c, method)
|
||||
}
|
||||
|
||||
// SendVerifyCode sends an email verification code for TOTP operations
|
||||
// POST /api/v1/user/totp/send-code
|
||||
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
@@ -13,6 +13,7 @@ func ProvideAdminHandlers(
|
||||
userHandler *admin.UserHandler,
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -32,6 +33,7 @@ func ProvideAdminHandlers(
|
||||
User: userHandler,
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -66,11 +68,13 @@ func ProvideHandlers(
|
||||
usageHandler *UsageHandler,
|
||||
redeemHandler *RedeemHandler,
|
||||
subscriptionHandler *SubscriptionHandler,
|
||||
announcementHandler *AnnouncementHandler,
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
soraGatewayHandler *SoraGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
@@ -79,11 +83,13 @@ func ProvideHandlers(
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Announcement: announcementHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
SoraGateway: soraGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,9 +102,11 @@ var ProviderSet = wire.NewSet(
|
||||
NewUsageHandler,
|
||||
NewRedeemHandler,
|
||||
NewSubscriptionHandler,
|
||||
NewAnnouncementHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewSoraGatewayHandler,
|
||||
NewTotpHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
@@ -106,6 +114,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewUserHandler,
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
@@ -33,7 +33,7 @@ const (
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// User-Agent(与 Antigravity-Manager 保持一致)
|
||||
UserAgent = "antigravity/1.11.9 windows/amd64"
|
||||
UserAgent = "antigravity/1.15.8 windows/amd64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
@@ -367,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
}
|
||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
||||
if block.Signature != "" {
|
||||
// signature 处理:
|
||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if !allowDummyThought {
|
||||
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
|
||||
@@ -407,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
},
|
||||
}
|
||||
// tool_use 的 signature 处理:
|
||||
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验)
|
||||
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路)
|
||||
if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
|
||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
|
||||
]`
|
||||
|
||||
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
|
||||
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
|
||||
if err != nil {
|
||||
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != "sig_tool_abc" {
|
||||
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
|
||||
contentNoSig := `[
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
|
||||
]`
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != dummyThoughtSignature {
|
||||
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
|
||||
}
|
||||
|
||||
@@ -9,11 +9,26 @@ const (
|
||||
BetaClaudeCode = "claude-code-20250219"
|
||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
BetaTokenCounting = "token-counting-2024-11-01"
|
||||
)
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
|
||||
//
|
||||
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
|
||||
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
|
||||
// even if the request doesn't use tools, otherwise upstream may reject the
|
||||
// request as a non-Claude-Code API request.
|
||||
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
|
||||
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
|
||||
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
|
||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||
var DefaultHeaders = map[string]string{
|
||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||
// Keep these in sync with recent Claude CLI traffic to reduce the chance
|
||||
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
|
||||
"User-Agent": "claude-cli/2.1.22 (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
"X-Stainless-Package-Version": "0.52.0",
|
||||
"X-Stainless-Package-Version": "0.70.0",
|
||||
"X-Stainless-OS": "Linux",
|
||||
"X-Stainless-Arch": "x64",
|
||||
"X-Stainless-Arch": "arm64",
|
||||
"X-Stainless-Runtime": "node",
|
||||
"X-Stainless-Runtime-Version": "v22.14.0",
|
||||
"X-Stainless-Runtime-Version": "v24.13.0",
|
||||
"X-Stainless-Retry-Count": "0",
|
||||
"X-Stainless-Timeout": "60",
|
||||
"X-Stainless-Timeout": "600",
|
||||
"X-App": "cli",
|
||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||
}
|
||||
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
|
||||
|
||||
// DefaultTestModel 测试时使用的默认模型
|
||||
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
||||
|
||||
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
|
||||
var ModelIDOverrides = map[string]string{
|
||||
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
|
||||
"claude-opus-4-5": "claude-opus-4-5-20251101",
|
||||
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
|
||||
}
|
||||
|
||||
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
|
||||
var ModelIDReverseOverrides = map[string]string{
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
|
||||
}
|
||||
|
||||
// NormalizeModelID 根据 Claude OAuth 规则映射模型
|
||||
func NormalizeModelID(id string) string {
|
||||
if id == "" {
|
||||
return id
|
||||
}
|
||||
if mapped, ok := ModelIDOverrides[id]; ok {
|
||||
return mapped
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// DenormalizeModelID 将上游模型 ID 转换为短名
|
||||
func DenormalizeModelID(id string) string {
|
||||
if id == "" {
|
||||
return id
|
||||
}
|
||||
if mapped, ok := ModelIDReverseOverrides[id]; ok {
|
||||
return mapped
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
|
||||
}
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
|
||||
// Log internal errors with full details for debugging
|
||||
if statusCode >= 500 && c.Request != nil {
|
||||
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
|
||||
}
|
||||
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
return true
|
||||
}
|
||||
|
||||
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
//go:build integration
|
||||
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Integration tests for verifying TLS fingerprint correctness.
|
||||
// These tests make actual network requests to external services and should be run manually.
|
||||
//
|
||||
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// skipIfExternalServiceUnavailable checks if the external service is available.
|
||||
// If not, it skips the test instead of failing.
|
||||
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
// Check for common network/TLS errors that indicate external service issues
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "certificate has expired") ||
|
||||
strings.Contains(errStr, "certificate is not yet valid") ||
|
||||
strings.Contains(errStr, "connection refused") ||
|
||||
strings.Contains(errStr, "no such host") ||
|
||||
strings.Contains(errStr, "network is unreachable") ||
|
||||
strings.Contains(errStr, "timeout") {
|
||||
t.Skipf("skipping test: external service unavailable: %v", err)
|
||||
}
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
||||
// This test uses tls.peet.ws to verify the fingerprint.
|
||||
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
||||
func TestJA3Fingerprint(t *testing.T) {
|
||||
// Skip if network is unavailable or if running in short mode
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
profile := &Profile{
|
||||
Name: "Claude CLI Test",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Use tls.peet.ws fingerprint detection API
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
skipIfExternalServiceUnavailable(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
}
|
||||
|
||||
// Log all fingerprint information
|
||||
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
||||
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
||||
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
||||
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
||||
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
||||
|
||||
// Verify JA3 hash matches expected value
|
||||
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
||||
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
||||
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
||||
} else {
|
||||
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
||||
}
|
||||
|
||||
// Verify JA4 fingerprint
|
||||
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
||||
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
||||
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
||||
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
||||
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
||||
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
||||
}
|
||||
|
||||
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
||||
// d = domain (SNI present), i = IP (no SNI)
|
||||
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
||||
expectedJA4Prefix := "t13d5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
||||
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
||||
} else {
|
||||
// Also accept 'i' variant for IP connections
|
||||
altPrefix := "t13i5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
||||
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
||||
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
||||
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
||||
} else {
|
||||
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
||||
}
|
||||
|
||||
// Verify extension list (should be 11 extensions including SNI)
|
||||
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
||||
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
||||
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
||||
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
||||
} else {
|
||||
t.Logf("Warning: JA3 extension list may differ")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileExpectation defines expected fingerprint values for a profile.
|
||||
type TestProfileExpectation struct {
|
||||
Profile *Profile
|
||||
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
|
||||
ExpectedJA4 string // Expected full JA4 (empty = don't check)
|
||||
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
|
||||
}
|
||||
|
||||
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
|
||||
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
|
||||
func TestAllProfiles(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Define all profiles to test with their expected fingerprints
|
||||
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
|
||||
profiles := []TestProfileExpectation{
|
||||
{
|
||||
// Linux x64 Node.js v22.17.1
|
||||
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
|
||||
Profile: &Profile{
|
||||
Name: "linux_x64_node_v22171",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part
|
||||
},
|
||||
{
|
||||
// MacOS arm64 Node.js v22.18.0
|
||||
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
|
||||
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
|
||||
Profile: &Profile{
|
||||
Name: "macos_arm64_node_v22180",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range profiles {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.Profile.Name, func(t *testing.T) {
|
||||
fp := fetchFingerprint(t, tc.Profile)
|
||||
if fp == nil {
|
||||
return // fetchFingerprint already called t.Fatal
|
||||
}
|
||||
|
||||
t.Logf("Profile: %s", tc.Profile.Name)
|
||||
t.Logf(" JA3: %s", fp.JA3)
|
||||
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
|
||||
t.Logf(" JA4: %s", fp.JA4)
|
||||
t.Logf(" PeetPrint: %s", fp.PeetPrint)
|
||||
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
|
||||
|
||||
// Verify expectations
|
||||
if tc.ExpectedJA3 != "" {
|
||||
if fp.JA3Hash == tc.ExpectedJA3 {
|
||||
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.ExpectedJA4 != "" {
|
||||
if fp.JA4 == tc.ExpectedJA4 {
|
||||
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
|
||||
}
|
||||
}
|
||||
|
||||
// Check JA4 cipher hash (stable middle part)
|
||||
// JA4 format: prefix_cipherHash_extHash
|
||||
if tc.JA4CipherHash != "" {
|
||||
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
|
||||
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
|
||||
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
|
||||
t.Helper()
|
||||
|
||||
dialer := NewDialer(profile, nil)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
skipIfExternalServiceUnavailable(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &fpResp.TLS
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Integration tests for verifying TLS fingerprint correctness.
|
||||
// These tests make actual network requests and should be run manually.
|
||||
// Unit tests for TLS fingerprint dialer.
|
||||
// Integration tests that require external network are in dialer_integration_test.go
|
||||
// and require the 'integration' build tag.
|
||||
//
|
||||
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
|
||||
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
|
||||
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
|
||||
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
|
||||
@@ -809,12 +809,21 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
|
||||
return err
|
||||
}
|
||||
|
||||
path := "{antigravity_quota_scopes," + string(scope) + "}"
|
||||
scopeKey := string(scope)
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
||||
path,
|
||||
`UPDATE accounts SET
|
||||
extra = jsonb_set(
|
||||
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
|
||||
ARRAY['antigravity_quota_scopes', $1]::text[],
|
||||
$2::jsonb,
|
||||
true
|
||||
),
|
||||
updated_at = NOW(),
|
||||
last_used_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL`,
|
||||
scopeKey,
|
||||
raw,
|
||||
id,
|
||||
)
|
||||
@@ -829,6 +838,7 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
||||
}
|
||||
@@ -849,12 +859,19 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
return err
|
||||
}
|
||||
|
||||
path := "{model_rate_limits," + scope + "}"
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
||||
path,
|
||||
`UPDATE accounts SET
|
||||
extra = jsonb_set(
|
||||
jsonb_set(COALESCE(extra, '{}'::jsonb), '{model_rate_limits}'::text[], COALESCE(extra->'model_rate_limits', '{}'::jsonb), true),
|
||||
ARRAY['model_rate_limits', $1]::text[],
|
||||
$2::jsonb,
|
||||
true
|
||||
),
|
||||
updated_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL`,
|
||||
scope,
|
||||
raw,
|
||||
id,
|
||||
)
|
||||
|
||||
95
backend/internal/repository/aes_encryptor.go
Normal file
95
backend/internal/repository/aes_encryptor.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// AESEncryptor implements SecretEncryptor using AES-256-GCM
|
||||
type AESEncryptor struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewAESEncryptor creates a new AES encryptor
|
||||
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
|
||||
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
|
||||
}
|
||||
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
|
||||
}
|
||||
|
||||
return &AESEncryptor{key: key}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext using AES-256-GCM
|
||||
// Output format: base64(nonce + ciphertext + tag)
|
||||
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the plaintext
|
||||
// Seal appends the ciphertext and tag to the nonce
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Encode as base64
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext using AES-256-GCM
|
||||
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
// Decode from base64
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode base64: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
83
backend/internal/repository/announcement_read_repo.go
Normal file
83
backend/internal/repository/announcement_read_repo.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementReadRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
|
||||
return &announcementReadRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return client.AnnouncementRead.Create().
|
||||
SetAnnouncementID(announcementID).
|
||||
SetUserID(userID).
|
||||
SetReadAt(readAt).
|
||||
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
|
||||
DoNothing().
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(announcementIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.UserIDEQ(userID),
|
||||
announcementread.AnnouncementIDIn(announcementIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].AnnouncementID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.AnnouncementIDEQ(announcementID),
|
||||
announcementread.UserIDIn(userIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].UserID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
|
||||
count, err := r.client.AnnouncementRead.Query().
|
||||
Where(announcementread.AnnouncementIDEQ(announcementID)).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
194
backend/internal/repository/announcement_repo.go
Normal file
194
backend/internal/repository/announcement_repo.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementRepository(client *dbent.Client) service.AnnouncementRepository {
|
||||
return &announcementRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Create(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.Create().
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applyAnnouncementEntityToService(a, created)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
|
||||
m, err := r.client.Announcement.Query().
|
||||
Where(announcement.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
return announcementEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Update(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.UpdateOneID(a.ID).
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
} else {
|
||||
builder.ClearStartsAt()
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
} else {
|
||||
builder.ClearEndsAt()
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
} else {
|
||||
builder.ClearCreatedBy()
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
} else {
|
||||
builder.ClearUpdatedBy()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
|
||||
a.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.Announcement.Delete().Where(announcement.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *announcementRepository) List(
|
||||
ctx context.Context,
|
||||
params pagination.PaginationParams,
|
||||
filters service.AnnouncementListFilters,
|
||||
) ([]service.Announcement, *pagination.PaginationResult, error) {
|
||||
q := r.client.Announcement.Query()
|
||||
|
||||
if filters.Status != "" {
|
||||
q = q.Where(announcement.StatusEQ(filters.Status))
|
||||
}
|
||||
if filters.Search != "" {
|
||||
q = q.Where(
|
||||
announcement.Or(
|
||||
announcement.TitleContainsFold(filters.Search),
|
||||
announcement.ContentContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
items, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(announcement.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
out := announcementEntitiesToService(items)
|
||||
return out, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
|
||||
q := r.client.Announcement.Query().
|
||||
Where(
|
||||
announcement.StatusEQ(service.AnnouncementStatusActive),
|
||||
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
|
||||
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
|
||||
).
|
||||
Order(dbent.Desc(announcement.FieldID))
|
||||
|
||||
items, err := q.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return announcementEntitiesToService(items), nil
|
||||
}
|
||||
|
||||
func applyAnnouncementEntityToService(dst *service.Announcement, src *dbent.Announcement) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Announcement{
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func announcementEntitiesToService(models []*dbent.Announcement) []service.Announcement {
|
||||
out := make([]service.Announcement, 0, len(models))
|
||||
for i := range models {
|
||||
if s := announcementEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -391,17 +391,20 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,13 +9,27 @@ import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
const (
|
||||
verifyCodeKeyPrefix = "verify_code:"
|
||||
passwordResetKeyPrefix = "password_reset:"
|
||||
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
||||
)
|
||||
|
||||
// verifyCodeKey generates the Redis key for email verification code.
|
||||
func verifyCodeKey(email string) string {
|
||||
return verifyCodeKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetKey generates the Redis key for password reset token.
|
||||
func passwordResetKey(email string) string {
|
||||
return passwordResetKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
||||
func passwordResetSentAtKey(email string) string {
|
||||
return passwordResetSentAtKeyPrefix + email
|
||||
}
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
|
||||
key := verifyCodeKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset token methods
|
||||
|
||||
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
|
||||
key := passwordResetKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data service.PasswordResetTokenData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
|
||||
key := passwordResetKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||
key := passwordResetKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset email cooldown methods
|
||||
|
||||
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||
key := passwordResetSentAtKey(email)
|
||||
exists, err := c.rdb.Exists(ctx, key).Result()
|
||||
return err == nil && exists > 0
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||
key := passwordResetSentAtKey(email)
|
||||
return c.rdb.Set(ctx, key, "1", ttl).Err()
|
||||
}
|
||||
|
||||
@@ -433,3 +433,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
||||
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||
func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(
|
||||
ctx,
|
||||
"SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var accountIDs []int64
|
||||
for rows.Next() {
|
||||
var accountID int64
|
||||
if err := rows.Scan(&accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accountIDs = append(accountIDs, accountID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return accountIDs, nil
|
||||
}
|
||||
|
||||
// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定)
|
||||
func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
if len(accountIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定
|
||||
_, err := r.sql.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO account_groups (account_id, group_id, priority, created_at)
|
||||
SELECT unnest($1::bigint[]), $2, 50, NOW()
|
||||
ON CONFLICT (account_id, group_id) DO NOTHING`,
|
||||
pq.Array(accountIDs),
|
||||
groupID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送调度器事件
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/imroc/req/v3"
|
||||
@@ -22,7 +22,7 @@ type openaiOAuthService struct {
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(s.tokenURL, proxyURL)
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
@@ -39,23 +39,24 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(s.tokenURL, proxyURL)
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
@@ -67,29 +68,25 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
|
||||
forceHTTP2 := false
|
||||
if parsedURL, err := url.Parse(tokenURL); err == nil {
|
||||
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
|
||||
}
|
||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 120 * time.Second,
|
||||
ForceHTTP2: forceHTTP2,
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 120 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||
}
|
||||
return &proxyProbeService{
|
||||
ipInfoURL: defaultIPInfoURL,
|
||||
insecureSkipVerify: insecure,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
validateResolvedIP: validateResolvedIP,
|
||||
@@ -36,12 +35,20 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
}
|
||||
|
||||
const (
|
||||
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// probeURLs 按优先级排列的探测 URL 列表
|
||||
// 某些 AI API 专用代理只允许访问特定域名,因此需要多个备选
|
||||
var probeURLs = []struct {
|
||||
url string
|
||||
parser string // "ip-api" or "httpbin"
|
||||
}{
|
||||
{"http://ip-api.com/json/?lang=zh-CN", "ip-api"},
|
||||
{"http://httpbin.org/ip", "httpbin"},
|
||||
}
|
||||
|
||||
type proxyProbeService struct {
|
||||
ipInfoURL string
|
||||
insecureSkipVerify bool
|
||||
allowPrivateHosts bool
|
||||
validateResolvedIP bool
|
||||
@@ -60,8 +67,21 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, probe := range probeURLs {
|
||||
exitInfo, latencyMs, err := s.probeWithURL(ctx, client, probe.url, probe.parser)
|
||||
if err == nil {
|
||||
return exitInfo, latencyMs, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("all probe URLs failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Client, url string, parser string) (*service.ProxyExitInfo, int64, error) {
|
||||
startTime := time.Now()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
@@ -78,6 +98,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
switch parser {
|
||||
case "ip-api":
|
||||
return s.parseIPAPI(body, latencyMs)
|
||||
case "httpbin":
|
||||
return s.parseHTTPBin(body, latencyMs)
|
||||
default:
|
||||
return nil, latencyMs, fmt.Errorf("unknown parser: %s", parser)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) parseIPAPI(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
|
||||
var ipInfo struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
@@ -89,13 +125,12 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
CountryCode string `json:"countryCode"`
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
||||
preview := string(body)
|
||||
if len(preview) > 200 {
|
||||
preview = preview[:200] + "..."
|
||||
}
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w (body: %s)", err, preview)
|
||||
}
|
||||
if strings.ToLower(ipInfo.Status) != "success" {
|
||||
if ipInfo.Message == "" {
|
||||
@@ -116,3 +151,19 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
CountryCode: ipInfo.CountryCode,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) parseHTTPBin(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
|
||||
// httpbin.org/ip 返回格式: {"origin": "1.2.3.4"}
|
||||
var result struct {
|
||||
Origin string `json:"origin"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse httpbin response: %w", err)
|
||||
}
|
||||
if result.Origin == "" {
|
||||
return nil, latencyMs, fmt.Errorf("httpbin: no IP found in response")
|
||||
}
|
||||
return &service.ProxyExitInfo{
|
||||
IP: result.Origin,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -21,7 +22,6 @@ type ProxyProbeServiceSuite struct {
|
||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.prober = &proxyProbeService{
|
||||
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
}
|
||||
@@ -49,12 +49,16 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
seen := make(chan string, 1)
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_IPAPI() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
|
||||
// 检查是否是 ip-api 请求
|
||||
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
|
||||
return
|
||||
}
|
||||
// 其他请求返回错误
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
@@ -65,45 +69,59 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
require.Equal(s.T(), "r", info.Region)
|
||||
require.Equal(s.T(), "cc", info.Country)
|
||||
require.Equal(s.T(), "CC", info.CountryCode)
|
||||
|
||||
// Verify proxy received the request
|
||||
select {
|
||||
case uri := <-seen:
|
||||
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
|
||||
default:
|
||||
require.Fail(s.T(), "expected proxy to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_HTTPBinFallback() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// ip-api 失败
|
||||
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
// httpbin 成功
|
||||
if strings.Contains(r.RequestURI, "httpbin.org") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"origin": "5.6.7.8"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.NoError(s.T(), err, "ProbeProxy should fallback to httpbin")
|
||||
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
|
||||
require.Equal(s.T(), "5.6.7.8", info.IP)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_AllFailed() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status: 503")
|
||||
require.ErrorContains(s.T(), err, "all probe URLs failed")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
return
|
||||
}
|
||||
// httpbin 也返回无效响应
|
||||
if strings.Contains(r.RequestURI, "httpbin.org") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "failed to parse response")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
|
||||
s.prober.ipInfoURL = "://invalid-url"
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
|
||||
require.ErrorContains(s.T(), err, "all probe URLs failed")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
||||
@@ -114,6 +132,40 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
||||
require.Error(s.T(), err, "expected error when proxy server is closed")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Success() {
|
||||
body := []byte(`{"status":"success","query":"1.2.3.4","city":"Beijing","regionName":"Beijing","country":"China","countryCode":"CN"}`)
|
||||
info, latencyMs, err := s.prober.parseIPAPI(body, 100)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), int64(100), latencyMs)
|
||||
require.Equal(s.T(), "1.2.3.4", info.IP)
|
||||
require.Equal(s.T(), "Beijing", info.City)
|
||||
require.Equal(s.T(), "Beijing", info.Region)
|
||||
require.Equal(s.T(), "China", info.Country)
|
||||
require.Equal(s.T(), "CN", info.CountryCode)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Failure() {
|
||||
body := []byte(`{"status":"fail","message":"rate limited"}`)
|
||||
_, _, err := s.prober.parseIPAPI(body, 100)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "rate limited")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_Success() {
|
||||
body := []byte(`{"origin": "9.8.7.6"}`)
|
||||
info, latencyMs, err := s.prober.parseHTTPBin(body, 50)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), int64(50), latencyMs)
|
||||
require.Equal(s.T(), "9.8.7.6", info.IP)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_NoIP() {
|
||||
body := []byte(`{"origin": ""}`)
|
||||
_, _, err := s.prober.parseHTTPBin(body, 50)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "no IP found")
|
||||
}
|
||||
|
||||
func TestProxyProbeServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ProxyProbeServiceSuite))
|
||||
}
|
||||
|
||||
@@ -202,6 +202,57 @@ func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, lim
|
||||
return redeemCodeEntitiesToService(codes), nil
|
||||
}
|
||||
|
||||
// ListByUserPaginated returns paginated balance/concurrency history for a user.
|
||||
// Supports optional type filter (e.g. "balance", "admin_balance", "concurrency", "admin_concurrency", "subscription").
|
||||
func (r *redeemCodeRepository) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.UsedByEQ(userID))
|
||||
|
||||
// Optional type filter
|
||||
if codeType != "" {
|
||||
q = q.Where(redeemcode.TypeEQ(codeType))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(redeemcode.FieldUsedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return redeemCodeEntitiesToService(codes), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SumPositiveBalanceByUser returns total recharged amount (sum of value > 0 where type is balance/admin_balance).
|
||||
func (r *redeemCodeRepository) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
|
||||
var result []struct {
|
||||
Sum float64 `json:"sum"`
|
||||
}
|
||||
err := r.client.RedeemCode.Query().
|
||||
Where(
|
||||
redeemcode.UsedByEQ(userID),
|
||||
redeemcode.ValueGT(0),
|
||||
redeemcode.TypeIn("balance", "admin_balance"),
|
||||
).
|
||||
Aggregate(dbent.As(dbent.Sum(redeemcode.FieldValue), "sum")).
|
||||
Scan(ctx, &result)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return result[0].Sum, nil
|
||||
}
|
||||
|
||||
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -26,7 +27,7 @@ func InitRedis(cfg *config.Config) *redis.Client {
|
||||
// buildRedisOptions 构建 Redis 连接选项
|
||||
// 从配置文件读取连接池和超时参数,支持生产环境调优
|
||||
func buildRedisOptions(cfg *config.Config) *redis.Options {
|
||||
return &redis.Options{
|
||||
opts := &redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
@@ -36,4 +37,13 @@ func buildRedisOptions(cfg *config.Config) *redis.Options {
|
||||
PoolSize: cfg.Redis.PoolSize, // 连接池大小
|
||||
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
|
||||
}
|
||||
|
||||
if cfg.Redis.EnableTLS {
|
||||
opts.TLSConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: cfg.Redis.Host,
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
@@ -32,4 +32,16 @@ func TestBuildRedisOptions(t *testing.T) {
|
||||
require.Equal(t, 4*time.Second, opts.WriteTimeout)
|
||||
require.Equal(t, 100, opts.PoolSize)
|
||||
require.Equal(t, 10, opts.MinIdleConns)
|
||||
require.Nil(t, opts.TLSConfig)
|
||||
|
||||
// Test case with TLS enabled
|
||||
cfgTLS := &config.Config{
|
||||
Redis: config.RedisConfig{
|
||||
Host: "localhost",
|
||||
EnableTLS: true,
|
||||
},
|
||||
}
|
||||
optsTLS := buildRedisOptions(cfgTLS)
|
||||
require.NotNil(t, optsTLS.TLSConfig)
|
||||
require.Equal(t, "localhost", optsTLS.TLSConfig.ServerName)
|
||||
}
|
||||
|
||||
@@ -77,21 +77,9 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
|
||||
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
|
||||
}
|
||||
|
||||
func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
|
||||
require.Equal(t, "2", forceHTTPVersion(t, client))
|
||||
}
|
||||
|
||||
func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
|
||||
require.Equal(t, "", forceHTTPVersion(t, client))
|
||||
}
|
||||
|
||||
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
|
||||
client := createOpenAIReqClient("http://proxy.local:8080")
|
||||
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
|
||||
}
|
||||
|
||||
|
||||
@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
return nil, false, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return []*service.Account{}, true, nil
|
||||
// 空快照视为缓存未命中,触发数据库回退查询
|
||||
// 这解决了新分组创建后立即绑定账号时的竞态条件问题
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(ids))
|
||||
|
||||
149
backend/internal/repository/totp_cache.go
Normal file
149
backend/internal/repository/totp_cache.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
totpSetupKeyPrefix = "totp:setup:"
|
||||
totpLoginKeyPrefix = "totp:login:"
|
||||
totpAttemptsKeyPrefix = "totp:attempts:"
|
||||
totpAttemptsTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
// TotpCache implements service.TotpCache using Redis
|
||||
type TotpCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewTotpCache creates a new TOTP cache
|
||||
func NewTotpCache(rdb *redis.Client) service.TotpCache {
|
||||
return &TotpCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// GetSetupSession retrieves a TOTP setup session
|
||||
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
|
||||
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||
data, err := c.rdb.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get setup session: %w", err)
|
||||
}
|
||||
|
||||
var session service.TotpSetupSession
|
||||
if err := json.Unmarshal(data, &session); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal setup session: %w", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// SetSetupSession stores a TOTP setup session
|
||||
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||
data, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal setup session: %w", err)
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
|
||||
return fmt.Errorf("set setup session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSetupSession deletes a TOTP setup session
|
||||
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// GetLoginSession retrieves a TOTP login session
|
||||
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
|
||||
key := totpLoginKeyPrefix + tempToken
|
||||
data, err := c.rdb.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get login session: %w", err)
|
||||
}
|
||||
|
||||
var session service.TotpLoginSession
|
||||
if err := json.Unmarshal(data, &session); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal login session: %w", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// SetLoginSession stores a TOTP login session
|
||||
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
|
||||
key := totpLoginKeyPrefix + tempToken
|
||||
data, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal login session: %w", err)
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
|
||||
return fmt.Errorf("set login session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteLoginSession deletes a TOTP login session
|
||||
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
|
||||
key := totpLoginKeyPrefix + tempToken
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// IncrementVerifyAttempts increments the verify attempt counter
|
||||
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||
|
||||
// Use pipeline for atomic increment and set TTL
|
||||
pipe := c.rdb.Pipeline()
|
||||
incrCmd := pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, totpAttemptsTTL)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, fmt.Errorf("increment verify attempts: %w", err)
|
||||
}
|
||||
|
||||
count, err := incrCmd.Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get increment result: %w", err)
|
||||
}
|
||||
|
||||
return int(count), nil
|
||||
}
|
||||
|
||||
// GetVerifyAttempts gets the current verify attempt count
|
||||
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, fmt.Errorf("get verify attempts: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ClearVerifyAttempts clears the verify attempt counter
|
||||
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at"
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
@@ -115,6 +115,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
reasoning_effort,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
@@ -122,7 +123,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -136,6 +137,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
mediaType := nullString(log.MediaType)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
|
||||
var requestIDArg any
|
||||
if requestID != "" {
|
||||
@@ -173,6 +175,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
mediaType,
|
||||
reasoningEffort,
|
||||
createdAt,
|
||||
}
|
||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
@@ -2094,6 +2097,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
mediaType sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
@@ -2129,6 +2133,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&mediaType,
|
||||
&reasoningEffort,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@@ -2191,6 +2196,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if mediaType.Valid {
|
||||
log.MediaType = &mediaType.String
|
||||
}
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
|
||||
return log, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -189,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
dbuser.Or(
|
||||
dbuser.EmailContainsFold(filters.Search),
|
||||
dbuser.UsernameContainsFold(filters.Search),
|
||||
dbuser.NotesContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -466,3 +468,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
|
||||
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
update := client.User.UpdateOneID(userID)
|
||||
if encryptedSecret == nil {
|
||||
update = update.ClearTotpSecretEncrypted()
|
||||
} else {
|
||||
update = update.SetTotpSecretEncrypted(*encryptedSecret)
|
||||
}
|
||||
_, err := update.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableTotp 启用用户的 TOTP 双因素认证
|
||||
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.User.UpdateOneID(userID).
|
||||
SetTotpEnabled(true).
|
||||
SetTotpEnabledAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableTotp 禁用用户的 TOTP 双因素认证
|
||||
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.User.UpdateOneID(userID).
|
||||
SetTotpEnabled(false).
|
||||
ClearTotpEnabledAt().
|
||||
ClearTotpSecretEncrypted().
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
q := client.UserSubscription.Query()
|
||||
if userID != nil {
|
||||
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
if groupID != nil {
|
||||
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
||||
}
|
||||
if status != "" {
|
||||
|
||||
// Status filtering with real-time expiration check
|
||||
now := time.Now()
|
||||
switch status {
|
||||
case service.SubscriptionStatusActive:
|
||||
// Active: status is active AND not yet expired
|
||||
q = q.Where(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtGT(now),
|
||||
)
|
||||
case service.SubscriptionStatusExpired:
|
||||
// Expired: status is expired OR (status is active but already expired)
|
||||
q = q.Where(
|
||||
usersubscription.Or(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusExpired),
|
||||
usersubscription.And(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtLTE(now),
|
||||
),
|
||||
),
|
||||
)
|
||||
case "":
|
||||
// No filter
|
||||
default:
|
||||
// Other status (e.g., revoked)
|
||||
q = q.Where(usersubscription.StatusEQ(status))
|
||||
}
|
||||
|
||||
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
q = q.WithUser().WithGroup().WithAssignedByUser()
|
||||
|
||||
// Determine sort field
|
||||
var field string
|
||||
switch sortBy {
|
||||
case "expires_at":
|
||||
field = usersubscription.FieldExpiresAt
|
||||
case "status":
|
||||
field = usersubscription.FieldStatus
|
||||
default:
|
||||
field = usersubscription.FieldCreatedAt
|
||||
}
|
||||
|
||||
// Determine sort order (default: desc)
|
||||
if sortOrder == "asc" && sortBy != "" {
|
||||
q = q.Order(dbent.Asc(field))
|
||||
} else {
|
||||
q = q.Order(dbent.Desc(field))
|
||||
}
|
||||
|
||||
subs, err := q.
|
||||
WithUser().
|
||||
WithGroup().
|
||||
WithAssignedByUser().
|
||||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
All(ctx)
|
||||
|
||||
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
group := s.mustCreateGroup("g-list")
|
||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
||||
|
||||
@@ -57,6 +57,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
@@ -83,6 +85,10 @@ var ProviderSet = wire.NewSet(
|
||||
NewSchedulerCache,
|
||||
NewSchedulerOutboxRepository,
|
||||
NewProxyLatencyCache,
|
||||
NewTotpCache,
|
||||
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
|
||||
@@ -201,7 +201,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
UserID: 1,
|
||||
GroupID: 10,
|
||||
StartsAt: deps.now,
|
||||
ExpiresAt: deps.now.Add(24 * time.Hour),
|
||||
ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
|
||||
Status: service.SubscriptionStatusActive,
|
||||
DailyUsageUSD: 1.23,
|
||||
WeeklyUsageUSD: 2.34,
|
||||
@@ -226,7 +226,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"user_id": 1,
|
||||
"group_id": 10,
|
||||
"starts_at": "2025-01-02T03:04:05Z",
|
||||
"expires_at": "2025-01-03T03:04:05Z",
|
||||
"expires_at": "2099-01-02T03:04:05Z",
|
||||
"status": "active",
|
||||
"daily_window_start": null,
|
||||
"weekly_window_start": null,
|
||||
@@ -457,6 +457,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
"registration_enabled": true,
|
||||
"email_verify_enabled": false,
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"smtp_host": "smtp.example.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_username": "user",
|
||||
@@ -490,8 +493,11 @@ func TestAPIContracts(t *testing.T) {
|
||||
"fallback_model_openai": "gpt-4o",
|
||||
"enable_identity_patch": true,
|
||||
"identity_patch_prompt": "",
|
||||
"invitation_code_enabled": false,
|
||||
"home_content": "",
|
||||
"hide_ccs_import_button": false
|
||||
"hide_ccs_import_button": false,
|
||||
"purchase_subscription_enabled": false,
|
||||
"purchase_subscription_url": ""
|
||||
}
|
||||
}`,
|
||||
},
|
||||
@@ -600,7 +606,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||
@@ -759,6 +765,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubApiKeyCache struct{}
|
||||
|
||||
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
@@ -868,6 +886,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubAccountRepo struct {
|
||||
bulkUpdateIDs []int64
|
||||
}
|
||||
@@ -1133,6 +1159,14 @@ func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit
|
||||
return append([]service.RedeemCode(nil), codes...), nil
|
||||
}
|
||||
|
||||
func (stubRedeemCodeRepo) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubRedeemCodeRepo) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct {
|
||||
byUser map[int64][]service.UserSubscription
|
||||
activeByUser map[int64][]service.UserSubscription
|
||||
@@ -1185,7 +1219,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
|
||||
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
|
||||
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
|
||||
// 账号管理
|
||||
registerAccountRoutes(admin, h)
|
||||
|
||||
// 公告管理
|
||||
registerAnnouncementRoutes(admin, h)
|
||||
|
||||
// OpenAI OAuth
|
||||
registerOpenAIOAuthRoutes(admin, h)
|
||||
|
||||
@@ -172,6 +175,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
||||
|
||||
// User attribute values
|
||||
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
||||
@@ -229,6 +233,18 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerAnnouncementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
announcements := admin.Group("/announcements")
|
||||
{
|
||||
announcements.GET("", h.Admin.Announcement.List)
|
||||
announcements.POST("", h.Admin.Announcement.Create)
|
||||
announcements.GET("/:id", h.Admin.Announcement.GetByID)
|
||||
announcements.PUT("/:id", h.Admin.Announcement.Update)
|
||||
announcements.DELETE("/:id", h.Admin.Announcement.Delete)
|
||||
announcements.GET("/:id/read-status", h.Admin.Announcement.ListReadStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
openai := admin.Group("/openai")
|
||||
{
|
||||
|
||||
@@ -26,11 +26,24 @@ func RegisterAuthRoutes(
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ValidatePromoCode)
|
||||
// 邀请码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/validate-invitation-code", rateLimiter.LimitWithOptions("validate-invitation", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ValidateInvitationCode)
|
||||
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
|
||||
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ForgotPassword)
|
||||
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
user.PUT("", h.User.UpdateProfile)
|
||||
|
||||
// TOTP 双因素认证
|
||||
totp := user.Group("/totp")
|
||||
{
|
||||
totp.GET("/status", h.Totp.GetStatus)
|
||||
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
|
||||
totp.POST("/send-code", h.Totp.SendVerifyCode)
|
||||
totp.POST("/setup", h.Totp.InitiateSetup)
|
||||
totp.POST("/enable", h.Totp.Enable)
|
||||
totp.POST("/disable", h.Totp.Disable)
|
||||
}
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
@@ -53,6 +64,13 @@ func RegisterUserRoutes(
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
|
||||
}
|
||||
|
||||
// 公告(用户可见)
|
||||
announcements := authenticated.Group("/announcements")
|
||||
{
|
||||
announcements.GET("", h.Announcement.List)
|
||||
announcements.POST("/:id/read", h.Announcement.MarkRead)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
redeem := authenticated.Group("/redeem")
|
||||
{
|
||||
|
||||
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) GetClaudeUserID() string {
|
||||
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
|
||||
@@ -124,7 +124,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
|
||||
"system": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
"text": claudeCodeSystemPrompt,
|
||||
"cache_control": map[string]string{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
|
||||
@@ -22,6 +22,10 @@ type AdminService interface {
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||
// codeType is optional - pass empty string to return all types.
|
||||
// Also returns totalRecharged (sum of all positive balance top-ups).
|
||||
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
|
||||
|
||||
// Group management
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
|
||||
@@ -115,6 +119,8 @@ type CreateGroupInput struct {
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
@@ -142,6 +148,8 @@ type UpdateGroupInput struct {
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
@@ -535,6 +543,21 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
// Aggregate total recharged amount (only once, regardless of type filter)
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, result.Total, totalRecharged, nil
|
||||
}
|
||||
|
||||
// Group management implementations
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
@@ -589,6 +612,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
}
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
||||
var accountIDsToCopy []int64
|
||||
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||
// 去重源分组 IDs
|
||||
seen := make(map[int64]struct{})
|
||||
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
|
||||
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
|
||||
if _, exists := seen[srcGroupID]; !exists {
|
||||
seen[srcGroupID] = struct{}{}
|
||||
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
// 校验源分组的平台是否与新分组一致
|
||||
for _, srcGroupID := range uniqueSourceGroupIDs {
|
||||
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
|
||||
}
|
||||
if srcGroup.Platform != platform {
|
||||
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有源分组的账号(去重)
|
||||
var err error
|
||||
accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
@@ -614,6 +669,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果有需要复制的账号,绑定到新分组
|
||||
if len(accountIDsToCopy) > 0 {
|
||||
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
|
||||
return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
|
||||
}
|
||||
group.AccountCount = int64(len(accountIDsToCopy))
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -761,6 +825,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
|
||||
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||
// 去重源分组 IDs
|
||||
seen := make(map[int64]struct{})
|
||||
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
|
||||
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
|
||||
// 校验:源分组不能是自身
|
||||
if srcGroupID == id {
|
||||
return nil, fmt.Errorf("cannot copy accounts from self")
|
||||
}
|
||||
// 去重
|
||||
if _, exists := seen[srcGroupID]; !exists {
|
||||
seen[srcGroupID] = struct{}{}
|
||||
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
// 校验源分组的平台是否与当前分组一致
|
||||
for _, srcGroupID := range uniqueSourceGroupIDs {
|
||||
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
|
||||
}
|
||||
if srcGroup.Platform != group.Platform {
|
||||
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有源分组的账号(去重)
|
||||
accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
|
||||
}
|
||||
|
||||
// 先清空当前分组的所有账号绑定
|
||||
if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
|
||||
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
|
||||
}
|
||||
|
||||
// 再绑定源分组的账号
|
||||
if len(accountIDsToCopy) > 0 {
|
||||
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
|
||||
return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
|
||||
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected EnableTotp call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected DisableTotp call")
|
||||
}
|
||||
|
||||
type groupRepoStub struct {
|
||||
affectedUserIDs []int64
|
||||
deleteErr error
|
||||
@@ -152,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
type proxyRepoStub struct {
|
||||
deleteErr error
|
||||
countErr error
|
||||
@@ -262,6 +282,14 @@ func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int
|
||||
panic("unexpected ListByUser call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserPaginated call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
|
||||
panic("unexpected SumPositiveBalanceByUser call")
|
||||
}
|
||||
|
||||
type subscriptionInvalidateCall struct {
|
||||
userID int64
|
||||
groupID int64
|
||||
|
||||
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
@@ -378,3 +386,11 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
@@ -152,6 +152,14 @@ func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params p
|
||||
return s.listWithFiltersCodes, result, nil
|
||||
}
|
||||
|
||||
func (s *redeemRepoStubForAdminList) ListByUserPaginated(_ context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserPaginated call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStubForAdminList) SumPositiveBalanceByUser(_ context.Context, userID int64) (float64, error) {
|
||||
panic("unexpected SumPositiveBalanceByUser call")
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &accountRepoStubForAdminList{
|
||||
|
||||
64
backend/internal/service/announcement.go
Normal file
64
backend/internal/service/announcement.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementStatusDraft = domain.AnnouncementStatusDraft
|
||||
AnnouncementStatusActive = domain.AnnouncementStatusActive
|
||||
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
|
||||
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementOperatorIn = domain.AnnouncementOperatorIn
|
||||
AnnouncementOperatorGT = domain.AnnouncementOperatorGT
|
||||
AnnouncementOperatorGTE = domain.AnnouncementOperatorGTE
|
||||
AnnouncementOperatorLT = domain.AnnouncementOperatorLT
|
||||
AnnouncementOperatorLTE = domain.AnnouncementOperatorLTE
|
||||
AnnouncementOperatorEQ = domain.AnnouncementOperatorEQ
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
|
||||
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
|
||||
)
|
||||
|
||||
type AnnouncementTargeting = domain.AnnouncementTargeting
|
||||
|
||||
type AnnouncementConditionGroup = domain.AnnouncementConditionGroup
|
||||
|
||||
type AnnouncementCondition = domain.AnnouncementCondition
|
||||
|
||||
type Announcement = domain.Announcement
|
||||
|
||||
type AnnouncementListFilters struct {
|
||||
Status string
|
||||
Search string
|
||||
}
|
||||
|
||||
type AnnouncementRepository interface {
|
||||
Create(ctx context.Context, a *Announcement) error
|
||||
GetByID(ctx context.Context, id int64) (*Announcement, error)
|
||||
Update(ctx context.Context, a *Announcement) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context, now time.Time) ([]Announcement, error)
|
||||
}
|
||||
|
||||
type AnnouncementReadRepository interface {
|
||||
MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error
|
||||
GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error)
|
||||
GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error)
|
||||
CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error)
|
||||
}
|
||||
378
backend/internal/service/announcement_service.go
Normal file
378
backend/internal/service/announcement_service.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type AnnouncementService struct {
|
||||
announcementRepo AnnouncementRepository
|
||||
readRepo AnnouncementReadRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
}
|
||||
|
||||
func NewAnnouncementService(
|
||||
announcementRepo AnnouncementRepository,
|
||||
readRepo AnnouncementReadRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
) *AnnouncementService {
|
||||
return &AnnouncementService{
|
||||
announcementRepo: announcementRepo,
|
||||
readRepo: readRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateAnnouncementInput struct {
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
}
|
||||
|
||||
type UpdateAnnouncementInput struct {
|
||||
Title *string
|
||||
Content *string
|
||||
Status *string
|
||||
Targeting *AnnouncementTargeting
|
||||
StartsAt **time.Time
|
||||
EndsAt **time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
Announcement Announcement
|
||||
ReadAt *time.Time
|
||||
}
|
||||
|
||||
type AnnouncementUserReadStatus struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Balance float64 `json:"balance"`
|
||||
Eligible bool `json:"eligible"`
|
||||
ReadAt *time.Time `json:"read_at,omitempty"`
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("create announcement: nil input")
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(input.Title)
|
||||
content := strings.TrimSpace(input.Content)
|
||||
if title == "" || len(title) > 200 {
|
||||
return nil, fmt.Errorf("create announcement: invalid title")
|
||||
}
|
||||
if content == "" {
|
||||
return nil, fmt.Errorf("create announcement: content is required")
|
||||
}
|
||||
|
||||
status := strings.TrimSpace(input.Status)
|
||||
if status == "" {
|
||||
status = AnnouncementStatusDraft
|
||||
}
|
||||
if !isValidAnnouncementStatus(status) {
|
||||
return nil, fmt.Errorf("create announcement: invalid status")
|
||||
}
|
||||
|
||||
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.StartsAt != nil && input.EndsAt != nil {
|
||||
if !input.StartsAt.Before(*input.EndsAt) {
|
||||
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
|
||||
}
|
||||
}
|
||||
|
||||
a := &Announcement{
|
||||
Title: title,
|
||||
Content: content,
|
||||
Status: status,
|
||||
Targeting: targeting,
|
||||
StartsAt: input.StartsAt,
|
||||
EndsAt: input.EndsAt,
|
||||
}
|
||||
if input.ActorID != nil && *input.ActorID > 0 {
|
||||
a.CreatedBy = input.ActorID
|
||||
a.UpdatedBy = input.ActorID
|
||||
}
|
||||
|
||||
if err := s.announcementRepo.Create(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("create announcement: %w", err)
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("update announcement: nil input")
|
||||
}
|
||||
|
||||
a, err := s.announcementRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Title != nil {
|
||||
title := strings.TrimSpace(*input.Title)
|
||||
if title == "" || len(title) > 200 {
|
||||
return nil, fmt.Errorf("update announcement: invalid title")
|
||||
}
|
||||
a.Title = title
|
||||
}
|
||||
if input.Content != nil {
|
||||
content := strings.TrimSpace(*input.Content)
|
||||
if content == "" {
|
||||
return nil, fmt.Errorf("update announcement: content is required")
|
||||
}
|
||||
a.Content = content
|
||||
}
|
||||
if input.Status != nil {
|
||||
status := strings.TrimSpace(*input.Status)
|
||||
if !isValidAnnouncementStatus(status) {
|
||||
return nil, fmt.Errorf("update announcement: invalid status")
|
||||
}
|
||||
a.Status = status
|
||||
}
|
||||
|
||||
if input.Targeting != nil {
|
||||
targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.Targeting = targeting
|
||||
}
|
||||
|
||||
if input.StartsAt != nil {
|
||||
a.StartsAt = *input.StartsAt
|
||||
}
|
||||
if input.EndsAt != nil {
|
||||
a.EndsAt = *input.EndsAt
|
||||
}
|
||||
|
||||
if a.StartsAt != nil && a.EndsAt != nil {
|
||||
if !a.StartsAt.Before(*a.EndsAt) {
|
||||
return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
|
||||
}
|
||||
}
|
||||
|
||||
if input.ActorID != nil && *input.ActorID > 0 {
|
||||
a.UpdatedBy = input.ActorID
|
||||
}
|
||||
|
||||
if err := s.announcementRepo.Update(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("update announcement: %w", err)
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.announcementRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete announcement: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) GetByID(ctx context.Context, id int64) (*Announcement, error) {
|
||||
return s.announcementRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
|
||||
return s.announcementRepo.List(ctx, params, filters)
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) ListForUser(ctx context.Context, userID int64, unreadOnly bool) ([]UserAnnouncement, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
|
||||
for i := range activeSubs {
|
||||
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
anns, err := s.announcementRepo.ListActive(ctx, now)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active announcements: %w", err)
|
||||
}
|
||||
|
||||
visible := make([]Announcement, 0, len(anns))
|
||||
ids := make([]int64, 0, len(anns))
|
||||
for i := range anns {
|
||||
a := anns[i]
|
||||
if !a.IsActiveAt(now) {
|
||||
continue
|
||||
}
|
||||
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
|
||||
continue
|
||||
}
|
||||
visible = append(visible, a)
|
||||
ids = append(ids, a.ID)
|
||||
}
|
||||
|
||||
if len(visible) == 0 {
|
||||
return []UserAnnouncement{}, nil
|
||||
}
|
||||
|
||||
readMap, err := s.readRepo.GetReadMapByUser(ctx, userID, ids)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get read map: %w", err)
|
||||
}
|
||||
|
||||
out := make([]UserAnnouncement, 0, len(visible))
|
||||
for i := range visible {
|
||||
a := visible[i]
|
||||
readAt, ok := readMap[a.ID]
|
||||
if unreadOnly && ok {
|
||||
continue
|
||||
}
|
||||
var ptr *time.Time
|
||||
if ok {
|
||||
t := readAt
|
||||
ptr = &t
|
||||
}
|
||||
out = append(out, UserAnnouncement{
|
||||
Announcement: a,
|
||||
ReadAt: ptr,
|
||||
})
|
||||
}
|
||||
|
||||
// 未读优先、同状态按创建时间倒序
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
ai, aj := out[i], out[j]
|
||||
if (ai.ReadAt == nil) != (aj.ReadAt == nil) {
|
||||
return ai.ReadAt == nil
|
||||
}
|
||||
return ai.Announcement.ID > aj.Announcement.ID
|
||||
})
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) MarkRead(ctx context.Context, userID, announcementID int64) error {
|
||||
// 安全:仅允许标记当前用户“可见”的公告
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
a, err := s.announcementRepo.GetByID(ctx, announcementID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if !a.IsActiveAt(now) {
|
||||
return ErrAnnouncementNotFound
|
||||
}
|
||||
|
||||
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
|
||||
for i := range activeSubs {
|
||||
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
|
||||
}
|
||||
|
||||
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
|
||||
return ErrAnnouncementNotFound
|
||||
}
|
||||
|
||||
if err := s.readRepo.MarkRead(ctx, announcementID, userID, now); err != nil {
|
||||
return fmt.Errorf("mark read: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) ListUserReadStatus(
|
||||
ctx context.Context,
|
||||
announcementID int64,
|
||||
params pagination.PaginationParams,
|
||||
search string,
|
||||
) ([]AnnouncementUserReadStatus, *pagination.PaginationResult, error) {
|
||||
ann, err := s.announcementRepo.GetByID(ctx, announcementID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
filters := UserListFilters{
|
||||
Search: strings.TrimSpace(search),
|
||||
}
|
||||
|
||||
users, page, err := s.userRepo.ListWithFilters(ctx, params, filters)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list users: %w", err)
|
||||
}
|
||||
|
||||
userIDs := make([]int64, 0, len(users))
|
||||
for i := range users {
|
||||
userIDs = append(userIDs, users[i].ID)
|
||||
}
|
||||
|
||||
readMap, err := s.readRepo.GetReadMapByUsers(ctx, announcementID, userIDs)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get read map: %w", err)
|
||||
}
|
||||
|
||||
out := make([]AnnouncementUserReadStatus, 0, len(users))
|
||||
for i := range users {
|
||||
u := users[i]
|
||||
subs, err := s.userSubRepo.ListActiveByUserID(ctx, u.ID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
activeGroupIDs := make(map[int64]struct{}, len(subs))
|
||||
for j := range subs {
|
||||
activeGroupIDs[subs[j].GroupID] = struct{}{}
|
||||
}
|
||||
|
||||
readAt, ok := readMap[u.ID]
|
||||
var ptr *time.Time
|
||||
if ok {
|
||||
t := readAt
|
||||
ptr = &t
|
||||
}
|
||||
|
||||
out = append(out, AnnouncementUserReadStatus{
|
||||
UserID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Balance: u.Balance,
|
||||
Eligible: domain.AnnouncementTargeting(ann.Targeting).Matches(u.Balance, activeGroupIDs),
|
||||
ReadAt: ptr,
|
||||
})
|
||||
}
|
||||
|
||||
return out, page, nil
|
||||
}
|
||||
|
||||
func isValidAnnouncementStatus(status string) bool {
|
||||
switch status {
|
||||
case AnnouncementStatusDraft, AnnouncementStatusActive, AnnouncementStatusArchived:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
66
backend/internal/service/announcement_targeting_test.go
Normal file
66
backend/internal/service/announcement_targeting_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAnnouncementTargeting_Matches_EmptyMatchesAll(t *testing.T) {
|
||||
var targeting AnnouncementTargeting
|
||||
require.True(t, targeting.Matches(0, nil))
|
||||
require.True(t, targeting.Matches(123.45, map[int64]struct{}{1: {}}))
|
||||
}
|
||||
|
||||
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup(t *testing.T) {
|
||||
targeting := AnnouncementTargeting{
|
||||
AnyOf: []AnnouncementConditionGroup{
|
||||
{AllOf: nil},
|
||||
},
|
||||
}
|
||||
_, err := targeting.NormalizeAndValidate()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
|
||||
}
|
||||
|
||||
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition(t *testing.T) {
|
||||
targeting := AnnouncementTargeting{
|
||||
AnyOf: []AnnouncementConditionGroup{
|
||||
{
|
||||
AllOf: []AnnouncementCondition{
|
||||
{Type: "balance", Operator: "between", Value: 10},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := targeting.NormalizeAndValidate()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
|
||||
}
|
||||
|
||||
func TestAnnouncementTargeting_Matches_AndOrSemantics(t *testing.T) {
|
||||
targeting := AnnouncementTargeting{
|
||||
AnyOf: []AnnouncementConditionGroup{
|
||||
{
|
||||
AllOf: []AnnouncementCondition{
|
||||
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorGTE, Value: 100},
|
||||
{Type: AnnouncementConditionTypeSubscription, Operator: AnnouncementOperatorIn, GroupIDs: []int64{10}},
|
||||
},
|
||||
},
|
||||
{
|
||||
AllOf: []AnnouncementCondition{
|
||||
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorLT, Value: 5},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 命中第 2 组(balance < 5)
|
||||
require.True(t, targeting.Matches(4.99, nil))
|
||||
require.False(t, targeting.Matches(5, nil))
|
||||
|
||||
// 命中第 1 组(balance >= 100 AND 订阅 in [10])
|
||||
require.False(t, targeting.Matches(100, map[int64]struct{}{}))
|
||||
require.False(t, targeting.Matches(99.9, map[int64]struct{}{10: {}}))
|
||||
require.True(t, targeting.Matches(100, map[int64]struct{}{10: {}}))
|
||||
}
|
||||
@@ -273,13 +273,11 @@ func logPrefix(sessionID, accountName string) string {
|
||||
}
|
||||
|
||||
// Antigravity 直接支持的模型(精确匹配透传)
|
||||
// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列
|
||||
var antigravitySupportedModels = map[string]bool{
|
||||
"claude-opus-4-5-thinking": true,
|
||||
"claude-sonnet-4-5": true,
|
||||
"claude-sonnet-4-5-thinking": true,
|
||||
"gemini-2.5-flash": true,
|
||||
"gemini-2.5-flash-lite": true,
|
||||
"gemini-2.5-flash-thinking": true,
|
||||
"gemini-3-flash": true,
|
||||
"gemini-3-pro-low": true,
|
||||
"gemini-3-pro-high": true,
|
||||
@@ -288,23 +286,32 @@ var antigravitySupportedModels = map[string]bool{
|
||||
|
||||
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
|
||||
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
|
||||
// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5)
|
||||
var antigravityPrefixMapping = []struct {
|
||||
prefix string
|
||||
target string
|
||||
}{
|
||||
// 长前缀优先
|
||||
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
|
||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
|
||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||
// gemini-2.5 → gemini-3 映射(长前缀优先)
|
||||
{"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash
|
||||
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image
|
||||
{"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash
|
||||
{"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash
|
||||
{"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high
|
||||
{"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high
|
||||
{"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high
|
||||
// gemini-3 前缀映射
|
||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
|
||||
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
|
||||
// Claude 映射
|
||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
|
||||
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
|
||||
{"claude-sonnet-4", "claude-sonnet-4-5"},
|
||||
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
|
||||
{"claude-opus-4", "claude-opus-4-5-thinking"},
|
||||
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
|
||||
}
|
||||
|
||||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||||
@@ -1530,7 +1537,11 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
|
||||
func antigravityUseScopeRateLimit() bool {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
// 默认开启按配额域限流,只有明确设置为禁用值时才关闭
|
||||
if v == "0" || v == "false" || v == "no" || v == "off" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
|
||||
@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
|
||||
// 3. Gemini 透传
|
||||
// 3. Gemini 2.5 → 3 映射
|
||||
{
|
||||
name: "Gemini透传 - gemini-2.5-flash",
|
||||
name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
|
||||
requestedModel: "gemini-2.5-flash",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-2.5-flash",
|
||||
expected: "gemini-3-flash",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-2.5-pro",
|
||||
name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
|
||||
requestedModel: "gemini-2.5-pro",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-2.5-pro",
|
||||
expected: "gemini-3-pro-high",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-future-model",
|
||||
|
||||
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
result.Email = userInfo.Email
|
||||
}
|
||||
|
||||
// 获取 project_id(部分账户类型可能没有)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
|
||||
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
result.ProjectID = loadResp.CloudAICompanionProject
|
||||
// 获取 project_id(部分账户类型可能没有),失败时重试
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
|
||||
if loadErr != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
|
||||
result.ProjectIDMissing = true
|
||||
} else {
|
||||
result.ProjectID = projectID
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
||||
tokenInfo.Email = existingEmail
|
||||
}
|
||||
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
|
||||
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
|
||||
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
|
||||
|
||||
if loadErr != nil {
|
||||
// LoadCodeAssist 失败,保留原有 project_id
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
|
||||
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
|
||||
if existingProjectID == "" {
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
}
|
||||
} else {
|
||||
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// loadProjectIDWithRetry 带重试机制获取 project_id
|
||||
// 返回 project_id 和错误,失败时会重试指定次数
|
||||
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// 指数退避:1s, 2s, 4s
|
||||
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
|
||||
if backoff > 8*time.Second {
|
||||
backoff = 8 * time.Second
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
return loadResp.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
} else if loadResp == nil {
|
||||
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
|
||||
} else {
|
||||
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
|
||||
@@ -89,3 +89,30 @@ func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *tim
|
||||
}
|
||||
return &resetAt
|
||||
}
|
||||
|
||||
var antigravityAllScopes = []AntigravityQuotaScope{
|
||||
AntigravityQuotaScopeClaude,
|
||||
AntigravityQuotaScopeGeminiText,
|
||||
AntigravityQuotaScopeGeminiImage,
|
||||
}
|
||||
|
||||
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
|
||||
if a == nil || a.Platform != PlatformAntigravity {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
result := make(map[string]int64)
|
||||
for _, scope := range antigravityAllScopes {
|
||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||
if resetAt != nil && now.Before(*resetAt) {
|
||||
remainingSec := int64(time.Until(*resetAt).Seconds())
|
||||
if remainingSec > 0 {
|
||||
result[string(scope)] = remainingSec
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
||||
}
|
||||
|
||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||
newCredentials["project_id"] = oldProjectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,只记录警告,不返回错误
|
||||
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
|
||||
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
|
||||
if tokenInfo.ProjectID != "" {
|
||||
// 有旧的 project_id,本次获取失败,保留旧值
|
||||
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
|
||||
} else {
|
||||
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
|
||||
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
|
||||
@@ -19,17 +19,19 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||
)
|
||||
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
@@ -47,6 +49,7 @@ type JWTClaims struct {
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
@@ -58,6 +61,7 @@ type AuthService struct {
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
userRepo UserRepository,
|
||||
redeemRepo RedeemCodeRepository,
|
||||
cfg *config.Config,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
@@ -67,6 +71,7 @@ func NewAuthService(
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
@@ -78,11 +83,11 @@ func NewAuthService(
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "")
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
|
||||
// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查是否需要邀请码
|
||||
var invitationRedeemCode *RedeemCode
|
||||
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||
if invitationCode == "" {
|
||||
return "", nil, ErrInvitationCodeRequired
|
||||
}
|
||||
// 验证邀请码
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err)
|
||||
return "", nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
// 检查类型和状态
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
|
||||
return "", nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
invitationRedeemCode = redeemCode
|
||||
}
|
||||
|
||||
// 检查是否需要邮件验证
|
||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||
@@ -153,6 +178,14 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 标记邀请码为已使用(如果使用了邀请码)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
// 邀请码标记失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供且功能已启用)
|
||||
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
@@ -580,3 +613,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
||||
// 生成新token
|
||||
return s.GenerateToken(user)
|
||||
}
|
||||
|
||||
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||
// 要求:必须同时开启邮件验证且 SMTP 配置正确
|
||||
func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return false
|
||||
}
|
||||
// Must have email verification enabled and SMTP configured
|
||||
if !s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
return s.settingService.IsPasswordResetEnabled(ctx)
|
||||
}
|
||||
|
||||
// preparePasswordReset validates the password reset request and returns necessary data
|
||||
// Returns (siteName, resetURL, shouldProceed)
|
||||
// shouldProceed is false when we should silently return success (to prevent enumeration)
|
||||
func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) {
|
||||
// Check if user exists (but don't reveal this to the caller)
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// Security: Log but don't reveal that user doesn't exist
|
||||
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
|
||||
return "", "", false
|
||||
}
|
||||
log.Printf("[Auth] Database error checking email for password reset: %v", err)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive() {
|
||||
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Get site name
|
||||
siteName := "Sub2API"
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
|
||||
// Build reset URL base
|
||||
resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/"))
|
||||
|
||||
return siteName, resetURL, true
|
||||
}
|
||||
|
||||
// RequestPasswordReset 请求密码重置(同步发送)
|
||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
if s.emailService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
|
||||
if !shouldProceed {
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset email sent to: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
|
||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
if s.emailQueueService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
|
||||
if !shouldProceed {
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
|
||||
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset email enqueued for: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// Security: Increments TokenVersion to invalidate all existing JWT tokens
|
||||
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
|
||||
// Check if password reset is enabled
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
|
||||
if s.emailService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// Verify and consume the reset token (one-time use)
|
||||
if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return ErrInvalidResetToken // Token was valid but user was deleted
|
||||
}
|
||||
log.Printf("[Auth] Database error getting user for password reset: %v", err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive() {
|
||||
return ErrUserNotActive
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPassword, err := s.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update password and increment TokenVersion
|
||||
user.PasswordHash = hashedPassword
|
||||
user.TokenVersion++ // Invalidate all existing tokens
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset successful for user: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
@@ -95,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
|
||||
return NewAuthService(
|
||||
repo,
|
||||
nil, // redeemRepo
|
||||
cfg,
|
||||
settingService,
|
||||
emailService,
|
||||
@@ -132,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
@@ -144,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
|
||||
@@ -158,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
|
||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||
require.ErrorContains(t, err, "verify code")
|
||||
}
|
||||
|
||||
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
|
||||
return s.CalculateCost(model, tokens, multiplier)
|
||||
}
|
||||
|
||||
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
|
||||
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
|
||||
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
|
||||
//
|
||||
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
|
||||
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
|
||||
// 范围内正常计费,范围外 × 2 计费
|
||||
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
|
||||
// 未启用长上下文计费,直接走正常计费
|
||||
if threshold <= 0 || extraMultiplier <= 1 {
|
||||
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||
}
|
||||
|
||||
// 计算总输入 token(缓存读取 + 新输入)
|
||||
total := tokens.CacheReadTokens + tokens.InputTokens
|
||||
if total <= threshold {
|
||||
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||
}
|
||||
|
||||
// 拆分成范围内和范围外
|
||||
var inRangeCacheTokens, inRangeInputTokens int
|
||||
var outRangeCacheTokens, outRangeInputTokens int
|
||||
|
||||
if tokens.CacheReadTokens >= threshold {
|
||||
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
|
||||
inRangeCacheTokens = threshold
|
||||
inRangeInputTokens = 0
|
||||
outRangeCacheTokens = tokens.CacheReadTokens - threshold
|
||||
outRangeInputTokens = tokens.InputTokens
|
||||
} else {
|
||||
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
|
||||
inRangeCacheTokens = tokens.CacheReadTokens
|
||||
inRangeInputTokens = threshold - tokens.CacheReadTokens
|
||||
outRangeCacheTokens = 0
|
||||
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
|
||||
}
|
||||
|
||||
// 范围内部分:正常计费
|
||||
inRangeTokens := UsageTokens{
|
||||
InputTokens: inRangeInputTokens,
|
||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
}
|
||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 范围外部分:× extraMultiplier 计费
|
||||
outRangeTokens := UsageTokens{
|
||||
InputTokens: outRangeInputTokens,
|
||||
CacheReadTokens: outRangeCacheTokens,
|
||||
}
|
||||
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
|
||||
if err != nil {
|
||||
return inRangeCost, nil // 出错时返回范围内成本
|
||||
}
|
||||
|
||||
// 合并成本
|
||||
return &CostBreakdown{
|
||||
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
|
||||
OutputCost: inRangeCost.OutputCost,
|
||||
CacheCreationCost: inRangeCost.CacheCreationCost,
|
||||
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
|
||||
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
|
||||
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
||||
func (s *BillingService) ListSupportedModels() []string {
|
||||
models := make([]string, 0)
|
||||
|
||||
@@ -1,67 +1,70 @@
|
||||
package service
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
|
||||
// Status constants
|
||||
const (
|
||||
StatusActive = "active"
|
||||
StatusDisabled = "disabled"
|
||||
StatusError = "error"
|
||||
StatusUnused = "unused"
|
||||
StatusUsed = "used"
|
||||
StatusExpired = "expired"
|
||||
StatusActive = domain.StatusActive
|
||||
StatusDisabled = domain.StatusDisabled
|
||||
StatusError = domain.StatusError
|
||||
StatusUnused = domain.StatusUnused
|
||||
StatusUsed = domain.StatusUsed
|
||||
StatusExpired = domain.StatusExpired
|
||||
)
|
||||
|
||||
// Role constants
|
||||
const (
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
RoleAdmin = domain.RoleAdmin
|
||||
RoleUser = domain.RoleUser
|
||||
)
|
||||
|
||||
// Platform constants
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformSora = "sora"
|
||||
PlatformAnthropic = domain.PlatformAnthropic
|
||||
PlatformOpenAI = domain.PlatformOpenAI
|
||||
PlatformGemini = domain.PlatformGemini
|
||||
PlatformAntigravity = domain.PlatformAntigravity
|
||||
PlatformSora = domain.PlatformSora
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
const (
|
||||
RedeemTypeBalance = "balance"
|
||||
RedeemTypeConcurrency = "concurrency"
|
||||
RedeemTypeSubscription = "subscription"
|
||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
const (
|
||||
PromoCodeStatusActive = "active"
|
||||
PromoCodeStatusDisabled = "disabled"
|
||||
PromoCodeStatusActive = domain.PromoCodeStatusActive
|
||||
PromoCodeStatusDisabled = domain.PromoCodeStatusDisabled
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
|
||||
AdjustmentTypeAdminBalance = domain.AdjustmentTypeAdminBalance // 管理员调整余额
|
||||
AdjustmentTypeAdminConcurrency = domain.AdjustmentTypeAdminConcurrency // 管理员调整并发数
|
||||
)
|
||||
|
||||
// Group subscription type constants
|
||||
const (
|
||||
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
|
||||
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
|
||||
SubscriptionTypeStandard = domain.SubscriptionTypeStandard // 标准计费模式(按余额扣费)
|
||||
SubscriptionTypeSubscription = domain.SubscriptionTypeSubscription // 订阅模式(按限额控制)
|
||||
)
|
||||
|
||||
// Subscription status constants
|
||||
const (
|
||||
SubscriptionStatusActive = "active"
|
||||
SubscriptionStatusExpired = "expired"
|
||||
SubscriptionStatusSuspended = "suspended"
|
||||
SubscriptionStatusActive = domain.SubscriptionStatusActive
|
||||
SubscriptionStatusExpired = domain.SubscriptionStatusExpired
|
||||
SubscriptionStatusSuspended = domain.SubscriptionStatusSuspended
|
||||
)
|
||||
|
||||
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
@@ -70,9 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
@@ -88,6 +93,9 @@ const (
|
||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||
|
||||
// TOTP 双因素认证设置
|
||||
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
|
||||
|
||||
// LinuxDo Connect OAuth 登录设置
|
||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||
@@ -95,14 +103,16 @@ const (
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
|
||||
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口
|
||||
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
|
||||
@@ -8,11 +8,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Task type constants
|
||||
const (
|
||||
TaskTypeVerifyCode = "verify_code"
|
||||
TaskTypePasswordReset = "password_reset"
|
||||
)
|
||||
|
||||
// EmailTask 邮件发送任务
|
||||
type EmailTask struct {
|
||||
Email string
|
||||
SiteName string
|
||||
TaskType string // "verify_code"
|
||||
TaskType string // "verify_code" or "password_reset"
|
||||
ResetURL string // Only used for password_reset task type
|
||||
}
|
||||
|
||||
// EmailQueueService 异步邮件队列服务
|
||||
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
||||
defer cancel()
|
||||
|
||||
switch task.TaskType {
|
||||
case "verify_code":
|
||||
case TaskTypeVerifyCode:
|
||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||
}
|
||||
case TaskTypePasswordReset:
|
||||
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
|
||||
}
|
||||
default:
|
||||
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
|
||||
}
|
||||
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: "verify_code",
|
||||
TaskType: TaskTypeVerifyCode,
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// EnqueuePasswordReset 将密码重置邮件任务加入队列
|
||||
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: TaskTypePasswordReset,
|
||||
ResetURL: resetURL,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.taskChan <- task:
|
||||
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("email queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止队列服务
|
||||
func (s *EmailQueueService) Stop() {
|
||||
close(s.stopChan)
|
||||
|
||||
@@ -3,11 +3,14 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +22,9 @@ var (
|
||||
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
|
||||
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
|
||||
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
|
||||
|
||||
// Password reset errors
|
||||
ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token")
|
||||
)
|
||||
|
||||
// EmailCache defines cache operations for email service
|
||||
@@ -26,6 +32,16 @@ type EmailCache interface {
|
||||
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||
DeleteVerificationCode(ctx context.Context, email string) error
|
||||
|
||||
// Password reset token methods
|
||||
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
|
||||
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
|
||||
DeletePasswordResetToken(ctx context.Context, email string) error
|
||||
|
||||
// Password reset email cooldown methods
|
||||
// Returns true if in cooldown period (email was sent recently)
|
||||
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
|
||||
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// VerificationCodeData represents verification code data
|
||||
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// PasswordResetTokenData represents password reset token data
|
||||
type PasswordResetTokenData struct {
|
||||
Token string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
maxVerifyCodeAttempts = 5
|
||||
|
||||
// Password reset token settings
|
||||
passwordResetTokenTTL = 30 * time.Minute
|
||||
|
||||
// Password reset email cooldown (prevent email bombing)
|
||||
passwordResetEmailCooldown = 30 * time.Second
|
||||
)
|
||||
|
||||
// SMTPConfig SMTP配置
|
||||
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
|
||||
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
||||
data.Attempts++
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
||||
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
|
||||
return client.Quit()
|
||||
}
|
||||
|
||||
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
|
||||
func (s *EmailService) GeneratePasswordResetToken() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// SendPasswordResetEmail sends a password reset email with a reset link
|
||||
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
|
||||
var token string
|
||||
var needSaveToken bool
|
||||
|
||||
// Check if token already exists
|
||||
existing, err := s.cache.GetPasswordResetToken(ctx, email)
|
||||
if err == nil && existing != nil {
|
||||
// Token exists, reuse it (allows resending email without generating new token)
|
||||
token = existing.Token
|
||||
needSaveToken = false
|
||||
} else {
|
||||
// Generate new token
|
||||
token, err = s.GeneratePasswordResetToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
needSaveToken = true
|
||||
}
|
||||
|
||||
// Save token to Redis (only if new token generated)
|
||||
if needSaveToken {
|
||||
data := &PasswordResetTokenData{
|
||||
Token: token,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil {
|
||||
return fmt.Errorf("save reset token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build full reset URL with URL-encoded token and email
|
||||
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
|
||||
|
||||
// Build email content
|
||||
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
|
||||
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
|
||||
|
||||
// Send email
|
||||
if err := s.SendEmail(ctx, email, subject, body); err != nil {
|
||||
return fmt.Errorf("send email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
|
||||
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
|
||||
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
|
||||
// Check email cooldown to prevent email bombing
|
||||
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
|
||||
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
|
||||
return nil // Silent success to prevent revealing cooldown to attackers
|
||||
}
|
||||
|
||||
// Send email using core method
|
||||
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set cooldown marker (Redis TTL handles expiration)
|
||||
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
|
||||
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyPasswordResetToken verifies the password reset token without consuming it
|
||||
func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error {
|
||||
data, err := s.cache.GetPasswordResetToken(ctx, email)
|
||||
if err != nil || data == nil {
|
||||
return ErrInvalidResetToken
|
||||
}
|
||||
|
||||
// Use constant-time comparison to prevent timing attacks
|
||||
if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 {
|
||||
return ErrInvalidResetToken
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
|
||||
func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error {
|
||||
// Verify first
|
||||
if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete after verification (one-time use)
|
||||
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
|
||||
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPasswordResetEmailBody builds the HTML content for password reset email
|
||||
func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string {
|
||||
return fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
|
||||
.header h1 { margin: 0; font-size: 24px; }
|
||||
.content { padding: 40px 30px; text-align: center; }
|
||||
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
|
||||
.button:hover { opacity: 0.9; }
|
||||
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
|
||||
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
|
||||
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||
.warning { color: #e74c3c; font-weight: 500; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>%s</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p style="font-size: 18px; color: #333;">密码重置请求</p>
|
||||
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
|
||||
<a href="%s" class="button">重置密码</a>
|
||||
<div class="info">
|
||||
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
|
||||
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
|
||||
</div>
|
||||
<div class="link-fallback">
|
||||
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
|
||||
<p>%s</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>这是一封自动发送的邮件,请勿回复。</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`, siteName, resetURL, resetURL)
|
||||
}
|
||||
|
||||
23
backend/internal/service/gateway_beta_test.go
Normal file
23
backend/internal/service/gateway_beta_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMergeAnthropicBeta(t *testing.T) {
|
||||
got := mergeAnthropicBeta(
|
||||
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
|
||||
"foo, oauth-2025-04-20,bar, foo",
|
||||
)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
|
||||
}
|
||||
|
||||
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
|
||||
got := mergeAnthropicBeta(
|
||||
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
|
||||
"",
|
||||
)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
|
||||
}
|
||||
@@ -269,6 +269,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
62
backend/internal/service/gateway_oauth_metadata_test.go
Normal file
62
backend/internal/service/gateway_oauth_metadata_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Stream: true,
|
||||
MetadataUserID: "",
|
||||
System: nil,
|
||||
Messages: nil,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
|
||||
}
|
||||
|
||||
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
|
||||
|
||||
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
|
||||
require.NotEmpty(t, got)
|
||||
|
||||
// Legacy format: user_{client}_account__session_{uuid}
|
||||
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
|
||||
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
|
||||
}
|
||||
|
||||
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Stream: true,
|
||||
MetadataUserID: "",
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"account_uuid": "acc-uuid",
|
||||
"claude_user_id": "clientid123",
|
||||
"anthropic_user_id": "",
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
|
||||
require.NotEmpty(t, got)
|
||||
|
||||
// New format: user_{client}_account_{account_uuid}_session_{uuid}
|
||||
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
|
||||
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
system: "Custom prompt",
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Custom prompt",
|
||||
wantSecondText: claudePrefix + "\n\nCustom prompt",
|
||||
},
|
||||
{
|
||||
name: "string system equals Claude Code prompt",
|
||||
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
// Claude Code + Custom = 2
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Custom",
|
||||
wantSecondText: claudePrefix + "\n\nCustom",
|
||||
},
|
||||
{
|
||||
name: "array system with existing Claude Code prompt (should dedupe)",
|
||||
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
// Claude Code at start + Other = 2 (deduped)
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Other",
|
||||
wantSecondText: claudePrefix + "\n\nOther",
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
|
||||
21
backend/internal/service/gateway_sanitize_test.go
Normal file
21
backend/internal/service/gateway_sanitize_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
|
||||
in := "You are OpenCode, the best coding agent on the planet."
|
||||
got := sanitizeSystemText(in)
|
||||
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
|
||||
}
|
||||
|
||||
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
|
||||
in := "OpenCode and opencode are mentioned."
|
||||
got := sanitizeToolDescription(in)
|
||||
// We no longer rewrite tool descriptions; only redact obvious path leaks.
|
||||
require.Equal(t, in, got)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -36,6 +36,11 @@ const (
|
||||
geminiRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
|
||||
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
|
||||
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
const geminiDummyThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
type GeminiMessagesCompatService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
if err != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
}
|
||||
geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq)
|
||||
originalClaudeBody := body
|
||||
|
||||
proxyURL := ""
|
||||
@@ -931,6 +937,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
}
|
||||
|
||||
// 图片生成计费
|
||||
imageCount := 0
|
||||
imageSize := s.extractImageSize(body)
|
||||
if isImageGenerationModel(originalModel) {
|
||||
imageCount = 1
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
@@ -938,6 +951,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Stream: req.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -969,6 +984,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||||
}
|
||||
|
||||
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
|
||||
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
|
||||
body = ensureGeminiFunctionCallThoughtSignatures(body)
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
@@ -1371,6 +1390,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
usage = &ClaudeUsage{}
|
||||
}
|
||||
|
||||
// 图片生成计费
|
||||
imageCount := 0
|
||||
imageSize := s.extractImageSize(body)
|
||||
if isImageGenerationModel(originalModel) {
|
||||
imageCount = 1
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
@@ -1378,6 +1404,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -2504,9 +2532,13 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
|
||||
}
|
||||
prompt, _ := asInt(usageMeta["promptTokenCount"])
|
||||
cand, _ := asInt(usageMeta["candidatesTokenCount"])
|
||||
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
return &ClaudeUsage{
|
||||
InputTokens: prompt,
|
||||
OutputTokens: cand,
|
||||
InputTokens: prompt - cached,
|
||||
OutputTokens: cand,
|
||||
CacheReadInputTokens: cached,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2635,6 +2667,58 @@ func nextGeminiDailyResetUnix() *int64 {
|
||||
return &ts
|
||||
}
|
||||
|
||||
func ensureGeminiFunctionCallThoughtSignatures(body []byte) []byte {
|
||||
// Fast path: only run when functionCall is present.
|
||||
if !bytes.Contains(body, []byte(`"functionCall"`)) {
|
||||
return body
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
contentsAny, ok := payload["contents"].([]any)
|
||||
if !ok || len(contentsAny) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
for _, c := range contentsAny {
|
||||
cm, ok := c.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
partsAny, ok := cm["parts"].([]any)
|
||||
if !ok || len(partsAny) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, p := range partsAny {
|
||||
pm, ok := p.(map[string]any)
|
||||
if !ok || pm == nil {
|
||||
continue
|
||||
}
|
||||
if fc, ok := pm["functionCall"].(map[string]any); !ok || fc == nil {
|
||||
continue
|
||||
}
|
||||
ts, _ := pm["thoughtSignature"].(string)
|
||||
if strings.TrimSpace(ts) == "" {
|
||||
pm["thoughtSignature"] = geminiDummyThoughtSignature
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func extractGeminiFinishReason(geminiResp map[string]any) string {
|
||||
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if cand, ok := candidates[0].(map[string]any); ok {
|
||||
@@ -2834,7 +2918,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
|
||||
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
|
||||
toolUseIDToName[id] = name
|
||||
}
|
||||
signature, _ := bm["signature"].(string)
|
||||
signature = strings.TrimSpace(signature)
|
||||
if signature == "" {
|
||||
signature = geminiDummyThoughtSignature
|
||||
}
|
||||
parts = append(parts, map[string]any{
|
||||
"thoughtSignature": signature,
|
||||
"functionCall": map[string]any{
|
||||
"name": name,
|
||||
"args": bm["input"],
|
||||
@@ -3031,3 +3121,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||||
func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
|
||||
var req struct {
|
||||
GenerationConfig *struct {
|
||||
ImageConfig *struct {
|
||||
ImageSize string `json:"imageSize"`
|
||||
} `json:"imageConfig"`
|
||||
} `json:"generationConfig"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return "2K"
|
||||
}
|
||||
|
||||
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
|
||||
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
|
||||
if size == "1K" || size == "2K" || size == "4K" {
|
||||
return size
|
||||
}
|
||||
}
|
||||
|
||||
return "2K"
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||
claudeReq := map[string]any{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"max_tokens": 10,
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "hi"},
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "ok"},
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "default_api:write_file",
|
||||
"input": map[string]any{"path": "a.txt", "content": "x"},
|
||||
// no signature on purpose
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"name": "default_api:write_file",
|
||||
"description": "write file",
|
||||
"input_schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{"path": map[string]any{"type": "string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(claudeReq)
|
||||
|
||||
out, err := convertClaudeMessagesToGeminiGenerateContent(b)
|
||||
if err != nil {
|
||||
t.Fatalf("convert failed: %v", err)
|
||||
}
|
||||
s := string(out)
|
||||
if !strings.Contains(s, "\"functionCall\"") {
|
||||
t.Fatalf("expected functionCall in output, got: %s", s)
|
||||
}
|
||||
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
|
||||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing.T) {
|
||||
geminiReq := map[string]any{
|
||||
"contents": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": "default_api:write_file",
|
||||
"args": map[string]any{"path": "a.txt"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(geminiReq)
|
||||
out := ensureGeminiFunctionCallThoughtSignatures(b)
|
||||
s := string(out)
|
||||
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
|
||||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,6 +221,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||
|
||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||
|
||||
72
backend/internal/service/gemini_native_signature_cleaner.go
Normal file
72
backend/internal/service/gemini_native_signature_cleaner.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
|
||||
// 以避免跨账号签名验证错误。
|
||||
//
|
||||
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
|
||||
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
|
||||
//
|
||||
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
|
||||
// to avoid cross-account signature validation errors.
|
||||
//
|
||||
// When sticky session switches accounts (e.g., original account becomes unavailable),
|
||||
// thoughtSignatures from the old account will cause validation failures on the new account.
|
||||
// By removing these signatures, we allow the new account to generate valid signatures.
|
||||
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
// 解析 JSON
|
||||
var data any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
|
||||
return body
|
||||
}
|
||||
|
||||
// 递归清理 thoughtSignature
|
||||
cleaned := cleanThoughtSignaturesRecursive(data)
|
||||
|
||||
// 重新序列化
|
||||
result, err := json.Marshal(cleaned)
|
||||
if err != nil {
|
||||
// 如果序列化失败,返回原始 body
|
||||
return body
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
|
||||
func cleanThoughtSignaturesRecursive(data any) any {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
// 创建新的 map,移除 thoughtSignature
|
||||
result := make(map[string]any, len(v))
|
||||
for key, value := range v {
|
||||
// 跳过 thoughtSignature 字段
|
||||
if key == "thoughtSignature" {
|
||||
continue
|
||||
}
|
||||
// 递归处理嵌套结构
|
||||
result[key] = cleanThoughtSignaturesRecursive(value)
|
||||
}
|
||||
return result
|
||||
|
||||
case []any:
|
||||
// 递归处理数组中的每个元素
|
||||
result := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = cleanThoughtSignaturesRecursive(item)
|
||||
}
|
||||
return result
|
||||
|
||||
default:
|
||||
// 基本类型(string, number, bool, null)直接返回
|
||||
return v
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,10 @@ type GroupRepository interface {
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
|
||||
// BindAccountsToGroup 将多个账号绑定到指定分组
|
||||
BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error
|
||||
}
|
||||
|
||||
// CreateGroupRequest 创建分组请求
|
||||
|
||||
@@ -26,13 +26,13 @@ var (
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||
UserAgent: "claude-cli/2.1.22 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.52.0",
|
||||
StainlessPackageVersion: "0.70.0",
|
||||
StainlessOS: "Linux",
|
||||
StainlessArch: "x64",
|
||||
StainlessArch: "arm64",
|
||||
StainlessRuntime: "node",
|
||||
StainlessRuntimeVersion: "v22.14.0",
|
||||
StainlessRuntimeVersion: "v24.13.0",
|
||||
}
|
||||
|
||||
// Fingerprint represents account fingerprint data
|
||||
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
|
||||
}
|
||||
|
||||
// parseUserAgentVersion 解析user-agent版本号
|
||||
// 例如:claude-cli/2.0.62 -> (2, 0, 62)
|
||||
// 例如:claude-cli/2.1.2 -> (2, 1, 2)
|
||||
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
||||
// 匹配 xxx/x.y.z 格式
|
||||
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
||||
|
||||
@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
|
||||
type NormalizedCodexLimits struct {
|
||||
Used5hPercent *float64
|
||||
Reset5hSeconds *int
|
||||
Window5hMinutes *int
|
||||
Used7dPercent *float64
|
||||
Reset7dSeconds *int
|
||||
Window7dMinutes *int
|
||||
}
|
||||
|
||||
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
|
||||
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
|
||||
// Returns nil if snapshot is nil or has no useful data.
|
||||
func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := &NormalizedCodexLimits{}
|
||||
|
||||
primaryMins := 0
|
||||
secondaryMins := 0
|
||||
hasPrimaryWindow := false
|
||||
hasSecondaryWindow := false
|
||||
|
||||
if s.PrimaryWindowMinutes != nil {
|
||||
primaryMins = *s.PrimaryWindowMinutes
|
||||
hasPrimaryWindow = true
|
||||
}
|
||||
if s.SecondaryWindowMinutes != nil {
|
||||
secondaryMins = *s.SecondaryWindowMinutes
|
||||
hasSecondaryWindow = true
|
||||
}
|
||||
|
||||
// Determine mapping based on window_minutes
|
||||
use5hFromPrimary := false
|
||||
use7dFromPrimary := false
|
||||
|
||||
if hasPrimaryWindow && hasSecondaryWindow {
|
||||
// Both known: smaller window is 5h, larger is 7d
|
||||
if primaryMins < secondaryMins {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
} else if hasPrimaryWindow {
|
||||
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
|
||||
if primaryMins <= 360 {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
} else if hasSecondaryWindow {
|
||||
// Only secondary known: classify by threshold
|
||||
if secondaryMins <= 360 {
|
||||
// 5h from secondary, so primary (if any data) is 7d
|
||||
use7dFromPrimary = true
|
||||
} else {
|
||||
// 7d from secondary, so primary (if any data) is 5h
|
||||
use5hFromPrimary = true
|
||||
}
|
||||
} else {
|
||||
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
|
||||
// Assign values
|
||||
if use5hFromPrimary {
|
||||
result.Used5hPercent = s.PrimaryUsedPercent
|
||||
result.Reset5hSeconds = s.PrimaryResetAfterSeconds
|
||||
result.Window5hMinutes = s.PrimaryWindowMinutes
|
||||
result.Used7dPercent = s.SecondaryUsedPercent
|
||||
result.Reset7dSeconds = s.SecondaryResetAfterSeconds
|
||||
result.Window7dMinutes = s.SecondaryWindowMinutes
|
||||
} else if use7dFromPrimary {
|
||||
result.Used7dPercent = s.PrimaryUsedPercent
|
||||
result.Reset7dSeconds = s.PrimaryResetAfterSeconds
|
||||
result.Window7dMinutes = s.PrimaryWindowMinutes
|
||||
result.Used5hPercent = s.SecondaryUsedPercent
|
||||
result.Reset5hSeconds = s.SecondaryResetAfterSeconds
|
||||
result.Window5hMinutes = s.SecondaryWindowMinutes
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// OpenAIUsage represents OpenAI API response usage
|
||||
type OpenAIUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
@@ -70,12 +156,15 @@ type OpenAIUsage struct {
|
||||
|
||||
// OpenAIForwardResult represents the result of forwarding
|
||||
type OpenAIForwardResult struct {
|
||||
RequestID string
|
||||
Usage OpenAIUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
RequestID string
|
||||
Usage OpenAIUsage
|
||||
Model string
|
||||
// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
|
||||
// Stored for usage records display; nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
}
|
||||
|
||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||
@@ -756,6 +845,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
bodyModified = true
|
||||
}
|
||||
}
|
||||
|
||||
// Remove prompt_cache_retention (not supported by upstream OpenAI API)
|
||||
if _, has := reqBody["prompt_cache_retention"]; has {
|
||||
delete(reqBody, "prompt_cache_retention")
|
||||
bodyModified = true
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
@@ -867,18 +962,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == AccountTypeOAuth {
|
||||
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1174,15 +1272,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
||||
lastDataAt := time.Now()
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱。
|
||||
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
|
||||
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
|
||||
errorEventSent := false
|
||||
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
if errorEventSent || clientDisconnected {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
payload := map[string]any{
|
||||
"type": "error",
|
||||
"sequence_number": 0,
|
||||
"error": map[string]any{
|
||||
"type": "upstream_error",
|
||||
"message": reason,
|
||||
"code": reason,
|
||||
},
|
||||
}
|
||||
if b, err := json.Marshal(payload); err == nil {
|
||||
_, _ = fmt.Fprintf(w, "data: %s\n\n", b)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
@@ -1194,6 +1306,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
log.Printf("Context canceled during streaming, returning collected usage")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
@@ -1217,15 +1340,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
||||
data = correctedData
|
||||
line = "data: " + correctedData
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
// 写入客户端(客户端断开后继续 drain 上游)
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
@@ -1235,11 +1362,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// Forward non-data lines as-is
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
@@ -1247,6 +1377,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
if s.rateLimitService != nil {
|
||||
@@ -1256,11 +1390,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
continue
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
@@ -1601,6 +1740,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
@@ -1665,8 +1805,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractCodexUsageHeaders extracts Codex usage limits from response headers
|
||||
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
||||
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
||||
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||
snapshot := &OpenAICodexUsageSnapshot{}
|
||||
hasData := false
|
||||
|
||||
@@ -1740,6 +1881,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
|
||||
// Convert snapshot to map for merging into Extra
|
||||
updates := make(map[string]any)
|
||||
|
||||
// Save raw primary/secondary fields for debugging/tracing
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
@@ -1763,109 +1906,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
}
|
||||
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
|
||||
|
||||
// Normalize to canonical 5h/7d fields based on window_minutes
|
||||
// This fixes the issue where OpenAI's primary/secondary naming is reversed
|
||||
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
|
||||
|
||||
// IMPORTANT: We can only reliably determine window type from window_minutes field
|
||||
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
|
||||
|
||||
var primaryWindowMins, secondaryWindowMins int
|
||||
var hasPrimaryWindow, hasSecondaryWindow bool
|
||||
|
||||
// Only use window_minutes for reliable window size comparison
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
primaryWindowMins = *snapshot.PrimaryWindowMinutes
|
||||
hasPrimaryWindow = true
|
||||
}
|
||||
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
secondaryWindowMins = *snapshot.SecondaryWindowMinutes
|
||||
hasSecondaryWindow = true
|
||||
}
|
||||
|
||||
// Determine which is 5h and which is 7d
|
||||
var use5hFromPrimary, use7dFromPrimary bool
|
||||
var use5hFromSecondary, use7dFromSecondary bool
|
||||
|
||||
if hasPrimaryWindow && hasSecondaryWindow {
|
||||
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
|
||||
if primaryWindowMins < secondaryWindowMins {
|
||||
use5hFromPrimary = true
|
||||
use7dFromSecondary = true
|
||||
} else {
|
||||
use5hFromSecondary = true
|
||||
use7dFromPrimary = true
|
||||
// Normalize to canonical 5h/7d fields
|
||||
if normalized := snapshot.Normalize(); normalized != nil {
|
||||
if normalized.Used5hPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
|
||||
}
|
||||
} else if hasPrimaryWindow {
|
||||
// Only primary window size known: classify by absolute threshold
|
||||
if primaryWindowMins <= 360 {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
if normalized.Reset5hSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds
|
||||
}
|
||||
} else if hasSecondaryWindow {
|
||||
// Only secondary window size known: classify by absolute threshold
|
||||
if secondaryWindowMins <= 360 {
|
||||
use5hFromSecondary = true
|
||||
} else {
|
||||
use7dFromSecondary = true
|
||||
if normalized.Window5hMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes
|
||||
}
|
||||
} else {
|
||||
// No window_minutes available: cannot reliably determine window types
|
||||
// Fall back to legacy assumption (may be incorrect)
|
||||
// Assume primary=7d, secondary=5h based on historical observation
|
||||
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
|
||||
use5hFromSecondary = true
|
||||
if normalized.Used7dPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *normalized.Used7dPercent
|
||||
}
|
||||
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
|
||||
use7dFromPrimary = true
|
||||
if normalized.Reset7dSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds
|
||||
}
|
||||
}
|
||||
|
||||
// Write canonical 5h fields
|
||||
if use5hFromPrimary {
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
} else if use5hFromSecondary {
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
}
|
||||
}
|
||||
|
||||
// Write canonical 7d fields
|
||||
if use7dFromPrimary {
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
} else if use7dFromSecondary {
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
if normalized.Window7dMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1876,3 +1935,86 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}()
|
||||
}
|
||||
|
||||
func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
|
||||
if reqBody == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Primary: reasoning.effort
|
||||
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
||||
if effort, ok := reasoning["effort"].(string); ok {
|
||||
return normalizeOpenAIReasoningEffort(effort), true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: some clients may use a flat field.
|
||||
if effort, ok := reqBody["reasoning_effort"].(string); ok {
|
||||
return normalizeOpenAIReasoningEffort(effort), true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func deriveOpenAIReasoningEffortFromModel(model string) string {
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
modelID := strings.TrimSpace(model)
|
||||
if strings.Contains(modelID, "/") {
|
||||
parts := strings.Split(modelID, "/")
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
|
||||
parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
|
||||
switch r {
|
||||
case '-', '_', ' ':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
|
||||
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return &value
|
||||
}
|
||||
|
||||
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return &value
|
||||
}
|
||||
|
||||
func normalizeOpenAIReasoningEffort(raw string) string {
|
||||
value := strings.ToLower(strings.TrimSpace(raw))
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Normalize separators for "x-high"/"x_high" variants.
|
||||
value = strings.NewReplacer("-", "", "_", "", " ", "").Replace(value)
|
||||
|
||||
switch value {
|
||||
case "none", "minimal":
|
||||
return ""
|
||||
case "low", "medium", "high":
|
||||
return value
|
||||
case "xhigh", "extrahigh":
|
||||
return "xhigh"
|
||||
default:
|
||||
// Only store known effort levels for now to keep UI consistent.
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
|
||||
skipDefaultLoad bool
|
||||
}
|
||||
|
||||
type cancelReadCloser struct{}
|
||||
|
||||
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
|
||||
func (c cancelReadCloser) Close() error { return nil }
|
||||
|
||||
type failingGinWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *failingGinWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if c.acquireResults != nil {
|
||||
if result, ok := c.acquireResults[accountID]; ok {
|
||||
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
||||
t.Fatalf("expected stream timeout error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "stream_timeout") {
|
||||
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
|
||||
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
|
||||
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: cancelReadCloser{},
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
|
||||
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if result == nil || result.usage == nil {
|
||||
t.Fatalf("expected usage result")
|
||||
}
|
||||
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
|
||||
t.Fatalf("unexpected usage: %+v", *result.usage)
|
||||
}
|
||||
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
|
||||
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
if !errors.Is(err, bufio.ErrTooLong) {
|
||||
t.Fatalf("expected ErrTooLong, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "response_too_large") {
|
||||
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
||||
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
|
||||
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
// Generate PKCE values
|
||||
state, err := openai.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_STATE_FAILED", "failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := openai.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_VERIFIER_FAILED", "failed to generate code verifier: %v", err)
|
||||
}
|
||||
|
||||
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
||||
@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
// Generate session ID
|
||||
sessionID, err := openai.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_SESSION_FAILED", "failed to generate session ID: %v", err)
|
||||
}
|
||||
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||
}
|
||||
if proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
// Get session
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||
}
|
||||
if proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
// Exchange code for token
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
||||
// RefreshAccountToken refreshes token for an OpenAI account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if !account.IsOpenAI() {
|
||||
return nil, fmt.Errorf("account is not an OpenAI account")
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
||||
}
|
||||
|
||||
refreshToken := account.GetOpenAIRefreshToken()
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
|
||||
@@ -67,6 +67,8 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
||||
|
||||
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
|
||||
|
||||
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
|
||||
|
||||
if acc.Platform != "" {
|
||||
if _, ok := platform[acc.Platform]; !ok {
|
||||
platform[acc.Platform] = &PlatformAvailability{
|
||||
@@ -84,6 +86,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
||||
if hasError {
|
||||
p.ErrorCount++
|
||||
}
|
||||
if len(scopeRateLimits) > 0 {
|
||||
if p.ScopeRateLimitCount == nil {
|
||||
p.ScopeRateLimitCount = make(map[string]int64)
|
||||
}
|
||||
for scope := range scopeRateLimits {
|
||||
p.ScopeRateLimitCount[scope]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, grp := range acc.Groups {
|
||||
@@ -108,6 +118,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
||||
if hasError {
|
||||
g.ErrorCount++
|
||||
}
|
||||
if len(scopeRateLimits) > 0 {
|
||||
if g.ScopeRateLimitCount == nil {
|
||||
g.ScopeRateLimitCount = make(map[string]int64)
|
||||
}
|
||||
for scope := range scopeRateLimits {
|
||||
g.ScopeRateLimitCount[scope]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
displayGroupID := int64(0)
|
||||
@@ -140,6 +158,9 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
||||
item.RateLimitRemainingSec = &remainingSec
|
||||
}
|
||||
}
|
||||
if len(scopeRateLimits) > 0 {
|
||||
item.ScopeRateLimits = scopeRateLimits
|
||||
}
|
||||
if isOverloaded && acc.OverloadUntil != nil {
|
||||
item.OverloadUntil = acc.OverloadUntil
|
||||
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
|
||||
|
||||
@@ -39,22 +39,24 @@ type AccountConcurrencyInfo struct {
|
||||
|
||||
// PlatformAvailability aggregates account availability by platform.
|
||||
type PlatformAvailability struct {
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
}
|
||||
|
||||
// GroupAvailability aggregates account availability by group.
|
||||
type GroupAvailability struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
}
|
||||
|
||||
// AccountAvailability represents current availability for a single account.
|
||||
@@ -72,10 +74,11 @@ type AccountAvailability struct {
|
||||
IsOverloaded bool `json:"is_overloaded"`
|
||||
HasError bool `json:"has_error"`
|
||||
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
|
||||
ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
||||
}
|
||||
|
||||
@@ -83,6 +83,7 @@ type OpsAdvancedSettings struct {
|
||||
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
}
|
||||
|
||||
@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
|
||||
// handle429 处理429限流错误
|
||||
// 解析响应头获取重置时间,标记账号为限流状态
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
// 解析重置时间戳
|
||||
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
|
||||
if account.Platform == PlatformOpenAI {
|
||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("openai_account_rate_limited", "account_id", account.ID, "reset_at", *resetAt)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 尝试从响应头解析重置时间(Anthropic)
|
||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||
|
||||
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
||||
if resetTimestamp == "" {
|
||||
switch account.Platform {
|
||||
case PlatformOpenAI:
|
||||
// 尝试解析 OpenAI 的 usage_limit_reached 错误
|
||||
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
|
||||
return
|
||||
}
|
||||
case PlatformGemini, PlatformAntigravity:
|
||||
// 尝试解析 Gemini 格式(用于其他平台)
|
||||
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
|
||||
@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
}
|
||||
return
|
||||
}
|
||||
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
|
||||
return strings.Contains(msg, "sonnet")
|
||||
}
|
||||
|
||||
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
|
||||
// 返回 nil 表示无法从响应头中确定重置时间
|
||||
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
|
||||
snapshot := ParseCodexRateLimitHeaders(headers)
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 判断哪个限制被触发(used_percent >= 100)
|
||||
is7dExhausted := normalized.Used7dPercent != nil && *normalized.Used7dPercent >= 100
|
||||
is5hExhausted := normalized.Used5hPercent != nil && *normalized.Used5hPercent >= 100
|
||||
|
||||
// 优先使用被触发限制的重置时间
|
||||
if is7dExhausted && normalized.Reset7dSeconds != nil {
|
||||
resetAt := now.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
|
||||
slog.Info("openai_429_7d_limit_exhausted", "reset_after_seconds", *normalized.Reset7dSeconds, "reset_at", resetAt)
|
||||
return &resetAt
|
||||
}
|
||||
if is5hExhausted && normalized.Reset5hSeconds != nil {
|
||||
resetAt := now.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
|
||||
slog.Info("openai_429_5h_limit_exhausted", "reset_after_seconds", *normalized.Reset5hSeconds, "reset_at", resetAt)
|
||||
return &resetAt
|
||||
}
|
||||
|
||||
// 都未达到100%但收到429,使用较长的重置时间
|
||||
var maxResetSecs int
|
||||
if normalized.Reset7dSeconds != nil && *normalized.Reset7dSeconds > maxResetSecs {
|
||||
maxResetSecs = *normalized.Reset7dSeconds
|
||||
}
|
||||
if normalized.Reset5hSeconds != nil && *normalized.Reset5hSeconds > maxResetSecs {
|
||||
maxResetSecs = *normalized.Reset5hSeconds
|
||||
}
|
||||
if maxResetSecs > 0 {
|
||||
resetAt := now.Add(time.Duration(maxResetSecs) * time.Second)
|
||||
slog.Info("openai_429_using_max_reset", "max_reset_seconds", maxResetSecs, "reset_at", resetAt)
|
||||
return &resetAt
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
// OpenAI 的 usage_limit_reached 错误格式:
|
||||
//
|
||||
// {
|
||||
// "error": {
|
||||
// "message": "The usage limit has been reached",
|
||||
// "type": "usage_limit_reached",
|
||||
// "resets_at": 1769404154,
|
||||
// "resets_in_seconds": 133107
|
||||
// }
|
||||
// }
|
||||
func parseOpenAIRateLimitResetTime(body []byte) *int64 {
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
errObj, ok := parsed["error"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
|
||||
errType, _ := errObj["type"].(string)
|
||||
if errType != "usage_limit_reached" && errType != "rate_limit_exceeded" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 优先使用 resets_at(Unix 时间戳)
|
||||
if resetsAt, ok := errObj["resets_at"].(float64); ok {
|
||||
ts := int64(resetsAt)
|
||||
return &ts
|
||||
}
|
||||
if resetsAt, ok := errObj["resets_at"].(string); ok {
|
||||
if ts, err := strconv.ParseInt(resetsAt, 10, 64); err == nil {
|
||||
return &ts
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 resets_at,尝试使用 resets_in_seconds
|
||||
if resetsInSeconds, ok := errObj["resets_in_seconds"].(float64); ok {
|
||||
ts := time.Now().Unix() + int64(resetsInSeconds)
|
||||
return &ts
|
||||
}
|
||||
if resetsInSeconds, ok := errObj["resets_in_seconds"].(string); ok {
|
||||
if sec, err := strconv.ParseInt(resetsInSeconds, 10, 64); err == nil {
|
||||
ts := time.Now().Unix() + sec
|
||||
return &ts
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handle529 处理529过载错误
|
||||
// 根据配置设置过载冷却时间
|
||||
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||
|
||||
364
backend/internal/service/ratelimit_service_openai_test.go
Normal file
364
backend/internal/service/ratelimit_service_openai_test.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate headers when 7d limit is exhausted (100% used)
|
||||
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "384607") // ~4.5 days
|
||||
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
|
||||
headers.Set("x-codex-secondary-used-percent", "3")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "17369") // ~4.8 hours
|
||||
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should be approximately 384607 seconds from now
|
||||
expectedDuration := 384607 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_5hExhausted(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate headers when 5h limit is exhausted (100% used)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "50")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "500000")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "3600") // 1 hour
|
||||
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should be approximately 3600 seconds from now
|
||||
expectedDuration := 3600 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Neither limit at 100%, should use the longer reset time
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "80")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "100000")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "90")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "5000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should use the max (100000 seconds from 7d window)
|
||||
expectedDuration := 100000 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_NoCodexHeaders(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// No codex headers at all
|
||||
headers := http.Header{}
|
||||
headers.Set("content-type", "application/json")
|
||||
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
|
||||
if resetAt != nil {
|
||||
t.Errorf("expected nil resetAt when no codex headers, got %v", resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100") // This is 5h
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "3600") // 1 hour
|
||||
headers.Set("x-codex-primary-window-minutes", "300") // 5 hours - smaller!
|
||||
headers.Set("x-codex-secondary-used-percent", "50")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "500000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "10080") // 7 days - larger!
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should correctly identify that primary is 5h (smaller window) and use its reset time
|
||||
expectedDuration := 3600 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits(t *testing.T) {
|
||||
// Test the Normalize() method directly
|
||||
pUsed := 100.0
|
||||
pReset := 384607
|
||||
pWindow := 10080
|
||||
sUsed := 3.0
|
||||
sReset := 17369
|
||||
sWindow := 300
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: &pUsed,
|
||||
PrimaryResetAfterSeconds: &pReset,
|
||||
PrimaryWindowMinutes: &pWindow,
|
||||
SecondaryUsedPercent: &sUsed,
|
||||
SecondaryResetAfterSeconds: &sReset,
|
||||
SecondaryWindowMinutes: &sWindow,
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Primary has larger window (10080 > 300), so primary should be 7d
|
||||
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
|
||||
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
|
||||
}
|
||||
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 384607 {
|
||||
t.Errorf("expected Reset7dSeconds=384607, got %v", normalized.Reset7dSeconds)
|
||||
}
|
||||
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 3.0 {
|
||||
t.Errorf("expected Used5hPercent=3, got %v", normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 17369 {
|
||||
t.Errorf("expected Reset5hSeconds=17369, got %v", normalized.Reset5hSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
|
||||
// Test when only primary has data, no window_minutes
|
||||
pUsed := 80.0
|
||||
pReset := 50000
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: &pUsed,
|
||||
PrimaryResetAfterSeconds: &pReset,
|
||||
// No window_minutes, no secondary data
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Legacy assumption: primary=7d, secondary=5h
|
||||
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 80.0 {
|
||||
t.Errorf("expected Used7dPercent=80, got %v", normalized.Used7dPercent)
|
||||
}
|
||||
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 50000 {
|
||||
t.Errorf("expected Reset7dSeconds=50000, got %v", normalized.Reset7dSeconds)
|
||||
}
|
||||
// Secondary (5h) should be nil
|
||||
if normalized.Used5hPercent != nil {
|
||||
t.Errorf("expected Used5hPercent=nil, got %v", *normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds != nil {
|
||||
t.Errorf("expected Reset5hSeconds=nil, got %v", *normalized.Reset5hSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
|
||||
// Test when only secondary has data, no window_minutes
|
||||
sUsed := 60.0
|
||||
sReset := 3000
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
SecondaryUsedPercent: &sUsed,
|
||||
SecondaryResetAfterSeconds: &sReset,
|
||||
// No window_minutes, no primary data
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Legacy assumption: primary=7d, secondary=5h
|
||||
// So secondary goes to 5h
|
||||
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 60.0 {
|
||||
t.Errorf("expected Used5hPercent=60, got %v", normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 3000 {
|
||||
t.Errorf("expected Reset5hSeconds=3000, got %v", normalized.Reset5hSeconds)
|
||||
}
|
||||
// Primary (7d) should be nil
|
||||
if normalized.Used7dPercent != nil {
|
||||
t.Errorf("expected Used7dPercent=nil, got %v", *normalized.Used7dPercent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits_BothDataNoWindowMinutes(t *testing.T) {
|
||||
// Test when both have data but no window_minutes
|
||||
pUsed := 100.0
|
||||
pReset := 400000
|
||||
sUsed := 50.0
|
||||
sReset := 10000
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: &pUsed,
|
||||
PrimaryResetAfterSeconds: &pReset,
|
||||
SecondaryUsedPercent: &sUsed,
|
||||
SecondaryResetAfterSeconds: &sReset,
|
||||
// No window_minutes
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Legacy assumption: primary=7d, secondary=5h
|
||||
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
|
||||
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
|
||||
}
|
||||
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 400000 {
|
||||
t.Errorf("expected Reset7dSeconds=400000, got %v", normalized.Reset7dSeconds)
|
||||
}
|
||||
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 50.0 {
|
||||
t.Errorf("expected Used5hPercent=50, got %v", normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 10000 {
|
||||
t.Errorf("expected Reset5hSeconds=10000, got %v", normalized.Reset5hSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandle429_AnthropicPlatformUnaffected(t *testing.T) {
|
||||
// Verify that Anthropic platform accounts still use the original logic
|
||||
// This test ensures we don't break existing Claude account rate limiting
|
||||
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate Anthropic 429 headers
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-reset", "1737820800") // A future Unix timestamp
|
||||
|
||||
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
|
||||
// because it only handles OpenAI platform
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
|
||||
// Should return nil since there are no x-codex-* headers
|
||||
if resetAt != nil {
|
||||
t.Errorf("expected nil for Anthropic headers, got %v", resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_UserProvidedScenario(t *testing.T) {
|
||||
// This is the exact scenario from the user:
|
||||
// codex_7d_used_percent: 100
|
||||
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
|
||||
// codex_5h_used_percent: 3
|
||||
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
|
||||
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate headers matching user's data
|
||||
// Note: We need to map the canonical 5h/7d back to primary/secondary
|
||||
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "384607")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days = 10080 minutes
|
||||
headers.Set("x-codex-secondary-used-percent", "3")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "17369")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours = 300 minutes
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt for user scenario")
|
||||
}
|
||||
|
||||
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
|
||||
expectedDuration := 384607 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
|
||||
// Verify it's approximately 4.45 days (384607 seconds)
|
||||
duration := resetAt.Sub(before)
|
||||
actualDays := duration.Hours() / 24.0
|
||||
|
||||
// 384607 / 86400 = ~4.45 days
|
||||
if actualDays < 4.4 || actualDays > 4.5 {
|
||||
t.Errorf("expected ~4.45 days, got %.2f days", actualDays)
|
||||
}
|
||||
|
||||
t.Logf("User scenario: reset_at=%v, duration=%.2f days", resetAt, actualDays)
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset(t *testing.T) {
|
||||
// Test that we return nil when there's used_percent but no reset_after_seconds
|
||||
// This should cause the caller to use the default 5-minute fallback
|
||||
|
||||
svc := &RateLimitService{}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
// No reset_after_seconds!
|
||||
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
|
||||
// Should return nil since there's no reset time available
|
||||
if resetAt != nil {
|
||||
t.Errorf("expected nil when no reset_after_seconds, got %v", resetAt)
|
||||
}
|
||||
}
|
||||
@@ -49,6 +49,11 @@ type RedeemCodeRepository interface {
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error)
|
||||
ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error)
|
||||
// ListByUserPaginated returns paginated balance/concurrency history for a specific user.
|
||||
// codeType filter is optional - pass empty string to return all types.
|
||||
ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error)
|
||||
// SumPositiveBalanceByUser returns the total recharged amount (sum of positive balance values) for a user.
|
||||
SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error)
|
||||
}
|
||||
|
||||
// GenerateCodesRequest 生成兑换码请求
|
||||
@@ -126,7 +131,8 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
return nil, errors.New("count must be greater than 0")
|
||||
}
|
||||
|
||||
if req.Value <= 0 {
|
||||
// 邀请码类型不需要数值,其他类型需要
|
||||
if req.Type != RedeemTypeInvitation && req.Value <= 0 {
|
||||
return nil, errors.New("value must be greater than 0")
|
||||
}
|
||||
|
||||
@@ -139,6 +145,12 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
codeType = RedeemTypeBalance
|
||||
}
|
||||
|
||||
// 邀请码类型的 value 设为 0
|
||||
value := req.Value
|
||||
if codeType == RedeemTypeInvitation {
|
||||
value = 0
|
||||
}
|
||||
|
||||
codes := make([]RedeemCode, 0, req.Count)
|
||||
for i := 0; i < req.Count; i++ {
|
||||
code, err := s.GenerateRandomCode()
|
||||
@@ -149,7 +161,7 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
codes = append(codes, RedeemCode{
|
||||
Code: code,
|
||||
Type: codeType,
|
||||
Value: req.Value,
|
||||
Value: value,
|
||||
Status: StatusUnused,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,6 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyRegistrationEnabled,
|
||||
SettingKeyEmailVerifyEnabled,
|
||||
SettingKeyPromoCodeEnabled,
|
||||
SettingKeyPasswordResetEnabled,
|
||||
SettingKeyInvitationCodeEnabled,
|
||||
SettingKeyTotpEnabled,
|
||||
SettingKeyTurnstileEnabled,
|
||||
SettingKeyTurnstileSiteKey,
|
||||
SettingKeySiteName,
|
||||
@@ -71,6 +74,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyDocURL,
|
||||
SettingKeyHomeContent,
|
||||
SettingKeyHideCcsImportButton,
|
||||
SettingKeyPurchaseSubscriptionEnabled,
|
||||
SettingKeyPurchaseSubscriptionURL,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
}
|
||||
|
||||
@@ -86,21 +91,30 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
|
||||
}
|
||||
|
||||
// Password reset requires email verification to be enabled
|
||||
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
||||
passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true"
|
||||
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: emailVerifyEnabled,
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
PasswordResetEnabled: passwordResetEnabled,
|
||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -125,37 +139,47 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
|
||||
// Return a struct that matches the frontend's expected format
|
||||
return &struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo,omitempty"`
|
||||
SiteSubtitle string `json:"site_subtitle,omitempty"`
|
||||
APIBaseURL string `json:"api_base_url,omitempty"`
|
||||
ContactInfo string `json:"contact_info,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
HomeContent string `json:"home_content,omitempty"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo,omitempty"`
|
||||
SiteSubtitle string `json:"site_subtitle,omitempty"`
|
||||
APIBaseURL string `json:"api_base_url,omitempty"`
|
||||
ContactInfo string `json:"contact_info,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
HomeContent string `json:"home_content,omitempty"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: s.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -167,6 +191,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
|
||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
||||
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
|
||||
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
|
||||
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[SettingKeySMTPHost] = settings.SMTPHost
|
||||
@@ -203,6 +230,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyDocURL] = settings.DocURL
|
||||
updates[SettingKeyHomeContent] = settings.HomeContent
|
||||
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
|
||||
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
|
||||
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
|
||||
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
@@ -262,6 +291,44 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
|
||||
return value != "false"
|
||||
}
|
||||
|
||||
// IsInvitationCodeEnabled 检查是否启用邀请码注册功能
|
||||
func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyInvitationCodeEnabled)
|
||||
if err != nil {
|
||||
return false // 默认关闭
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||
// 要求:必须同时开启邮件验证
|
||||
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||
// Password reset requires email verification to be enabled
|
||||
if !s.IsEmailVerifyEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyPasswordResetEnabled)
|
||||
if err != nil {
|
||||
return false // 默认关闭
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能
|
||||
func (s *SettingService) IsTotpEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyTotpEnabled)
|
||||
if err != nil {
|
||||
return false // 默认关闭
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置
|
||||
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
|
||||
func (s *SettingService) IsTotpEncryptionKeyConfigured() bool {
|
||||
return s.cfg.Totp.EncryptionKeyConfigured
|
||||
}
|
||||
|
||||
// GetSiteName 获取网站名称
|
||||
func (s *SettingService) GetSiteName(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
||||
@@ -309,15 +376,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
|
||||
// 初始化默认设置
|
||||
defaults := map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
||||
SettingKeySiteName: "Sub2API",
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
||||
SettingKeySiteName: "Sub2API",
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||
SettingKeyPurchaseSubscriptionURL: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
// Model fallback defaults
|
||||
SettingKeyEnableModelFallback: "false",
|
||||
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
|
||||
@@ -340,10 +409,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
|
||||
// parseSettings 解析设置到结构体
|
||||
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
|
||||
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
EmailVerifyEnabled: emailVerifyEnabled,
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
|
||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
@@ -361,6 +434,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user