mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-14 20:04:46 +08:00
Merge up/main
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -35,24 +36,25 @@ const (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
}
|
||||
|
||||
// UpdateConfig 在线更新相关配置
|
||||
@@ -322,6 +324,30 @@ type TurnstileConfig struct {
|
||||
Required bool `mapstructure:"required"`
|
||||
}
|
||||
|
||||
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
|
||||
//
|
||||
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
|
||||
// 这里是用于登录 Sub2API 本身的用户体系。
|
||||
type LinuxDoConnectConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||
TokenURL string `mapstructure:"token_url"`
|
||||
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||
Scopes string `mapstructure:"scopes"`
|
||||
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
|
||||
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
|
||||
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
|
||||
UsePKCE bool `mapstructure:"use_pkce"`
|
||||
|
||||
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
|
||||
// 为空时,服务端会尝试一组常见字段名。
|
||||
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
|
||||
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
|
||||
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||
}
|
||||
|
||||
type DefaultConfig struct {
|
||||
AdminEmail string `mapstructure:"admin_email"`
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
@@ -388,6 +414,18 @@ func Load() (*Config, error) {
|
||||
cfg.Server.Mode = "debug"
|
||||
}
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
|
||||
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
|
||||
cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL)
|
||||
cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL)
|
||||
cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL)
|
||||
cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes)
|
||||
cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL)
|
||||
cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL)
|
||||
cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod))
|
||||
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
|
||||
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
|
||||
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
|
||||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||
@@ -426,6 +464,81 @@ func Load() (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
|
||||
func ValidateAbsoluteHTTPURL(raw string) error {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return fmt.Errorf("empty url")
|
||||
}
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !u.IsAbs() {
|
||||
return fmt.Errorf("must be absolute")
|
||||
}
|
||||
if !isHTTPScheme(u.Scheme) {
|
||||
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||
}
|
||||
if strings.TrimSpace(u.Host) == "" {
|
||||
return fmt.Errorf("missing host")
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
return fmt.Errorf("must not include fragment")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateFrontendRedirectURL 校验前端回调地址:
|
||||
// - 允许同源相对路径(以 / 开头)
|
||||
// - 或绝对 http(s) URL(禁止 fragment)
|
||||
func ValidateFrontendRedirectURL(raw string) error {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return fmt.Errorf("empty url")
|
||||
}
|
||||
if strings.ContainsAny(raw, "\r\n") {
|
||||
return fmt.Errorf("contains invalid characters")
|
||||
}
|
||||
if strings.HasPrefix(raw, "/") {
|
||||
if strings.HasPrefix(raw, "//") {
|
||||
return fmt.Errorf("must not start with //")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !u.IsAbs() {
|
||||
return fmt.Errorf("must be absolute http(s) url or relative path")
|
||||
}
|
||||
if !isHTTPScheme(u.Scheme) {
|
||||
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||
}
|
||||
if strings.TrimSpace(u.Host) == "" {
|
||||
return fmt.Errorf("missing host")
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
return fmt.Errorf("must not include fragment")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isHTTPScheme(scheme string) bool {
|
||||
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
|
||||
}
|
||||
|
||||
func warnIfInsecureURL(field, raw string) {
|
||||
u, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(u.Scheme, "http") {
|
||||
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
|
||||
}
|
||||
}
|
||||
|
||||
func setDefaults() {
|
||||
viper.SetDefault("run_mode", RunModeStandard)
|
||||
|
||||
@@ -475,6 +588,22 @@ func setDefaults() {
|
||||
// Turnstile
|
||||
viper.SetDefault("turnstile.required", false)
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
viper.SetDefault("linuxdo_connect.enabled", false)
|
||||
viper.SetDefault("linuxdo_connect.client_id", "")
|
||||
viper.SetDefault("linuxdo_connect.client_secret", "")
|
||||
viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize")
|
||||
viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user")
|
||||
viper.SetDefault("linuxdo_connect.scopes", "user")
|
||||
viper.SetDefault("linuxdo_connect.redirect_url", "")
|
||||
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
|
||||
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
|
||||
viper.SetDefault("linuxdo_connect.use_pkce", false)
|
||||
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
@@ -544,7 +673,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
@@ -586,6 +715,60 @@ func (c *Config) Validate() error {
|
||||
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
||||
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
||||
}
|
||||
if c.LinuxDo.Enabled {
|
||||
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.TokenURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod))
|
||||
switch method {
|
||||
case "", "client_secret_post", "client_secret_basic", "none":
|
||||
default:
|
||||
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
|
||||
}
|
||||
if method == "none" && !c.LinuxDo.UsePKCE {
|
||||
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
|
||||
}
|
||||
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err)
|
||||
}
|
||||
|
||||
warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL)
|
||||
warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL)
|
||||
warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL)
|
||||
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
|
||||
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
|
||||
}
|
||||
if c.Billing.CircuitBreaker.Enabled {
|
||||
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
||||
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.LinuxDo.Enabled = true
|
||||
cfg.LinuxDo.ClientID = "test-client"
|
||||
cfg.LinuxDo.ClientSecret = "test-secret"
|
||||
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
|
||||
cfg.LinuxDo.UsePKCE = false
|
||||
|
||||
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for javascript scheme, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") {
|
||||
t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.LinuxDo.Enabled = true
|
||||
cfg.LinuxDo.ClientID = "test-client"
|
||||
cfg.LinuxDo.ClientSecret = ""
|
||||
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
|
||||
cfg.LinuxDo.TokenAuthMethod = "none"
|
||||
cfg.LinuxDo.UsePKCE = false
|
||||
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
|
||||
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
@@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
accountType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
@@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.Status != "" ||
|
||||
req.Schedulable != nil ||
|
||||
req.GroupIDs != nil ||
|
||||
len(req.Credentials) > 0 ||
|
||||
len(req.Extra) > 0
|
||||
@@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: req.Status,
|
||||
Schedulable: req.Schedulable,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -67,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
platform := c.Query("platform")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
isExclusiveStr := c.Query("is_exclusive")
|
||||
|
||||
var isExclusive *bool
|
||||
@@ -75,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
isExclusive = &val
|
||||
}
|
||||
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
209
backend/internal/handler/admin/promo_handler.go
Normal file
209
backend/internal/handler/admin/promo_handler.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PromoHandler handles admin promo code management
|
||||
type PromoHandler struct {
|
||||
promoService *service.PromoService
|
||||
}
|
||||
|
||||
// NewPromoHandler creates a new admin promo handler
|
||||
func NewPromoHandler(promoService *service.PromoService) *PromoHandler {
|
||||
return &PromoHandler{
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePromoCodeRequest represents create promo code request
|
||||
type CreatePromoCodeRequest struct {
|
||||
Code string `json:"code"` // 可选,为空则自动生成
|
||||
BonusAmount float64 `json:"bonus_amount" binding:"required,min=0"` // 赠送余额
|
||||
MaxUses int `json:"max_uses" binding:"min=0"` // 最大使用次数,0=无限
|
||||
ExpiresAt *int64 `json:"expires_at"` // 过期时间戳(秒)
|
||||
Notes string `json:"notes"` // 备注
|
||||
}
|
||||
|
||||
// UpdatePromoCodeRequest represents update promo code request
|
||||
type UpdatePromoCodeRequest struct {
|
||||
Code *string `json:"code"`
|
||||
BonusAmount *float64 `json:"bonus_amount" binding:"omitempty,min=0"`
|
||||
MaxUses *int `json:"max_uses" binding:"omitempty,min=0"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
Notes *string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all promo codes with pagination
|
||||
// GET /api/v1/admin/promo-codes
|
||||
func (h *PromoHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := c.Query("status")
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.PromoCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.PromoCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a promo code by ID
|
||||
// GET /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) GetByID(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.promoService.GetByID(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Create handles creating a new promo code
|
||||
// POST /api/v1/admin/promo-codes
|
||||
func (h *PromoHandler) Create(c *gin.Context) {
|
||||
var req CreatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.CreatePromoCodeInput{
|
||||
Code: req.Code,
|
||||
BonusAmount: req.BonusAmount,
|
||||
MaxUses: req.MaxUses,
|
||||
Notes: req.Notes,
|
||||
}
|
||||
|
||||
if req.ExpiresAt != nil {
|
||||
t := time.Unix(*req.ExpiresAt, 0)
|
||||
input.ExpiresAt = &t
|
||||
}
|
||||
|
||||
code, err := h.promoService.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Update handles updating a promo code
|
||||
// PUT /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) Update(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.UpdatePromoCodeInput{
|
||||
Code: req.Code,
|
||||
BonusAmount: req.BonusAmount,
|
||||
MaxUses: req.MaxUses,
|
||||
Status: req.Status,
|
||||
Notes: req.Notes,
|
||||
}
|
||||
|
||||
if req.ExpiresAt != nil {
|
||||
if *req.ExpiresAt == 0 {
|
||||
// 0 表示清除过期时间
|
||||
input.ExpiresAt = nil
|
||||
} else {
|
||||
t := time.Unix(*req.ExpiresAt, 0)
|
||||
input.ExpiresAt = &t
|
||||
}
|
||||
}
|
||||
|
||||
code, err := h.promoService.Update(c.Request.Context(), codeID, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Delete handles deleting a promo code
|
||||
// DELETE /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) Delete(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.promoService.Delete(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Promo code deleted successfully"})
|
||||
}
|
||||
|
||||
// GetUsages handles getting usage records for a promo code
|
||||
// GET /api/v1/admin/promo-codes/:id/usages
|
||||
func (h *PromoHandler) GetUsages(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
usages, paginationResult, err := h.promoService.ListUsages(c.Request.Context(), codeID, params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.PromoCodeUsage, 0, len(usages))
|
||||
for i := range usages {
|
||||
out = append(out, *dto.PromoCodeUsageFromService(&usages[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
@@ -51,6 +51,11 @@ func (h *ProxyHandler) List(c *gin.Context) {
|
||||
protocol := c.Query("protocol")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
codeType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,8 +2,10 @@ package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -38,33 +40,37 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -88,6 +94,12 @@ type UpdateSettingsRequest struct {
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
@@ -165,34 +177,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||
req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret)
|
||||
req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL)
|
||||
|
||||
if req.LinuxDoConnectClientID == "" {
|
||||
response.BadRequest(c, "LinuxDo Client ID is required when enabled")
|
||||
return
|
||||
}
|
||||
if req.LinuxDoConnectRedirectURL == "" {
|
||||
response.BadRequest(c, "LinuxDo Redirect URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil {
|
||||
response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果未提供 client_secret,则保留现有值(如有)。
|
||||
if req.LinuxDoConnectClientSecret == "" {
|
||||
if previousSettings.LinuxDoConnectClientSecret == "" {
|
||||
response.BadRequest(c, "LinuxDo Client Secret is required when enabled")
|
||||
return
|
||||
}
|
||||
req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret
|
||||
}
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
@@ -210,33 +255,37 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if req.TurnstileSecretKey != "" {
|
||||
changed = append(changed, "turnstile_secret_key")
|
||||
}
|
||||
if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled {
|
||||
changed = append(changed, "linuxdo_connect_enabled")
|
||||
}
|
||||
if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID {
|
||||
changed = append(changed, "linuxdo_connect_client_id")
|
||||
}
|
||||
if req.LinuxDoConnectClientSecret != "" {
|
||||
changed = append(changed, "linuxdo_connect_client_secret")
|
||||
}
|
||||
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
|
||||
changed = append(changed, "linuxdo_connect_redirect_url")
|
||||
}
|
||||
if before.SiteName != after.SiteName {
|
||||
changed = append(changed, "site_name")
|
||||
}
|
||||
@@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
|
||||
changed = append(changed, "fallback_model_antigravity")
|
||||
}
|
||||
if before.EnableIdentityPatch != after.EnableIdentityPatch {
|
||||
changed = append(changed, "enable_identity_patch")
|
||||
}
|
||||
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
|
||||
changed = append(changed, "identity_patch_prompt")
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct {
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
filters := service.UserListFilters{
|
||||
Status: c.Query("status"),
|
||||
Role: c.Query("role"),
|
||||
Search: c.Query("search"),
|
||||
Search: search,
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
|
||||
|
||||
@@ -27,16 +27,20 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
|
||||
|
||||
// CreateAPIKeyRequest represents the create API key request payload
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
GroupID *int64 `json:"group_id"` // nullable
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
Name string `json:"name" binding:"required"`
|
||||
GroupID *int64 `json:"group_id"` // nullable
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest represents the update API key request payload
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
}
|
||||
|
||||
// List handles listing user's API keys with pagination
|
||||
@@ -110,9 +114,11 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
svcReq := service.CreateAPIKeyRequest{
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateAPIKeyRequest{}
|
||||
svcReq := service.UpdateAPIKeyRequest{
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
}
|
||||
|
||||
@@ -12,17 +12,21 @@ import (
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +36,7 @@ type RegisterRequest struct {
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
PromoCode string `json:"promo_code"` // 注册优惠码
|
||||
}
|
||||
|
||||
// SendVerifyCodeRequest 发送验证码请求
|
||||
@@ -77,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -172,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
|
||||
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
|
||||
}
|
||||
|
||||
// ValidatePromoCodeRequest 验证优惠码请求
|
||||
type ValidatePromoCodeRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// ValidatePromoCodeResponse 验证优惠码响应
|
||||
type ValidatePromoCodeResponse struct {
|
||||
Valid bool `json:"valid"`
|
||||
BonusAmount float64 `json:"bonus_amount,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
||||
// POST /api/v1/auth/validate-promo-code
|
||||
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
var req ValidatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
promoCode, err := h.promoService.ValidatePromoCode(c.Request.Context(), req.Code)
|
||||
if err != nil {
|
||||
// 根据错误类型返回对应的错误码
|
||||
errorCode := "PROMO_CODE_INVALID"
|
||||
switch err {
|
||||
case service.ErrPromoCodeNotFound:
|
||||
errorCode = "PROMO_CODE_NOT_FOUND"
|
||||
case service.ErrPromoCodeExpired:
|
||||
errorCode = "PROMO_CODE_EXPIRED"
|
||||
case service.ErrPromoCodeDisabled:
|
||||
errorCode = "PROMO_CODE_DISABLED"
|
||||
case service.ErrPromoCodeMaxUsed:
|
||||
errorCode = "PROMO_CODE_MAX_USED"
|
||||
case service.ErrPromoCodeAlreadyUsed:
|
||||
errorCode = "PROMO_CODE_ALREADY_USED"
|
||||
}
|
||||
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: errorCode,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if promoCode == nil {
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "PROMO_CODE_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: true,
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
})
|
||||
}
|
||||
|
||||
679
backend/internal/handler/auth_linuxdo_oauth.go
Normal file
679
backend/internal/handler/auth_linuxdo_oauth.go
Normal file
@@ -0,0 +1,679 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
|
||||
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
|
||||
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
|
||||
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
|
||||
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
|
||||
linuxDoOAuthDefaultRedirectTo = "/dashboard"
|
||||
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
|
||||
|
||||
linuxDoOAuthMaxRedirectLen = 2048
|
||||
linuxDoOAuthMaxFragmentValueLen = 512
|
||||
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
|
||||
)
|
||||
|
||||
type linuxDoTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
type linuxDoTokenExchangeError struct {
|
||||
StatusCode int
|
||||
ProviderError string
|
||||
ProviderDescription string
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *linuxDoTokenExchangeError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
|
||||
if strings.TrimSpace(e.ProviderError) != "" {
|
||||
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
|
||||
}
|
||||
if strings.TrimSpace(e.ProviderDescription) != "" {
|
||||
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
|
||||
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
|
||||
func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
|
||||
cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
|
||||
if redirectTo == "" {
|
||||
redirectTo = linuxDoOAuthDefaultRedirectTo
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
|
||||
codeChallenge := ""
|
||||
if cfg.UsePKCE {
|
||||
verifier, err := oauth.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
|
||||
return
|
||||
}
|
||||
codeChallenge = oauth.GenerateCodeChallenge(verifier)
|
||||
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
}
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
if redirectURI == "" {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
|
||||
return
|
||||
}
|
||||
|
||||
authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
|
||||
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
|
||||
func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context())
|
||||
if cfgErr != nil {
|
||||
response.ErrorFrom(c, cfgErr)
|
||||
return
|
||||
}
|
||||
|
||||
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
|
||||
if frontendCallback == "" {
|
||||
frontendCallback = linuxDoOAuthDefaultFrontendCB
|
||||
}
|
||||
|
||||
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
|
||||
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if code == "" || state == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
|
||||
return
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
defer func() {
|
||||
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
|
||||
}()
|
||||
|
||||
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
|
||||
if err != nil || expectedState == "" || state != expectedState {
|
||||
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
|
||||
return
|
||||
}
|
||||
|
||||
redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie)
|
||||
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
|
||||
if redirectTo == "" {
|
||||
redirectTo = linuxDoOAuthDefaultRedirectTo
|
||||
}
|
||||
|
||||
codeVerifier := ""
|
||||
if cfg.UsePKCE {
|
||||
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
|
||||
if codeVerifier == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
if redirectURI == "" {
|
||||
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
|
||||
return
|
||||
}
|
||||
|
||||
tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
|
||||
if err != nil {
|
||||
description := ""
|
||||
var exchangeErr *linuxDoTokenExchangeError
|
||||
if errors.As(err, &exchangeErr) && exchangeErr != nil {
|
||||
log.Printf(
|
||||
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
|
||||
exchangeErr.StatusCode,
|
||||
exchangeErr.ProviderError,
|
||||
exchangeErr.ProviderDescription,
|
||||
truncateLogValue(exchangeErr.Body, 2048),
|
||||
)
|
||||
description = exchangeErr.Error()
|
||||
} else {
|
||||
log.Printf("[LinuxDo OAuth] token exchange failed: %v", err)
|
||||
description = err.Error()
|
||||
}
|
||||
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
|
||||
return
|
||||
}
|
||||
|
||||
email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
|
||||
if err != nil {
|
||||
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
|
||||
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
|
||||
return
|
||||
}
|
||||
|
||||
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
|
||||
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
|
||||
if subject != "" {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
|
||||
if err != nil {
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
|
||||
fragment := url.Values{}
|
||||
fragment.Set("access_token", jwtToken)
|
||||
fragment.Set("token_type", "Bearer")
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||
}
|
||||
if h == nil || h.cfg == nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||
}
|
||||
if !h.cfg.LinuxDo.Enabled {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||
}
|
||||
return h.cfg.LinuxDo, nil
|
||||
}
|
||||
|
||||
func linuxDoExchangeCode(
|
||||
ctx context.Context,
|
||||
cfg config.LinuxDoConnectConfig,
|
||||
code string,
|
||||
redirectURI string,
|
||||
codeVerifier string,
|
||||
) (*linuxDoTokenResponse, error) {
|
||||
client := req.C().SetTimeout(30 * time.Second)
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("client_id", cfg.ClientID)
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
if cfg.UsePKCE {
|
||||
form.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
|
||||
r := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json")
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
|
||||
case "", "client_secret_post":
|
||||
form.Set("client_secret", cfg.ClientSecret)
|
||||
case "client_secret_basic":
|
||||
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||
case "none":
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
|
||||
}
|
||||
|
||||
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request token: %w", err)
|
||||
}
|
||||
body := strings.TrimSpace(resp.String())
|
||||
if !resp.IsSuccessState() {
|
||||
providerErr, providerDesc := parseOAuthProviderError(body)
|
||||
return nil, &linuxDoTokenExchangeError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ProviderError: providerErr,
|
||||
ProviderDescription: providerDesc,
|
||||
Body: body,
|
||||
}
|
||||
}
|
||||
|
||||
tokenResp, ok := parseLinuxDoTokenResponse(body)
|
||||
if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||
return nil, &linuxDoTokenExchangeError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: body,
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(tokenResp.TokenType) == "" {
|
||||
tokenResp.TokenType = "Bearer"
|
||||
}
|
||||
return tokenResp, nil
|
||||
}
|
||||
|
||||
func linuxDoFetchUserInfo(
|
||||
ctx context.Context,
|
||||
cfg config.LinuxDoConnectConfig,
|
||||
token *linuxDoTokenResponse,
|
||||
) (email string, username string, subject string, err error) {
|
||||
client := req.C().SetTimeout(30 * time.Second)
|
||||
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Authorization", authorization).
|
||||
Get(cfg.UserInfoURL)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("request userinfo: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return linuxDoParseUserInfo(resp.String(), cfg)
|
||||
}
|
||||
|
||||
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
|
||||
email = firstNonEmpty(
|
||||
getGJSON(body, cfg.UserInfoEmailPath),
|
||||
getGJSON(body, "email"),
|
||||
getGJSON(body, "user.email"),
|
||||
getGJSON(body, "data.email"),
|
||||
getGJSON(body, "attributes.email"),
|
||||
)
|
||||
username = firstNonEmpty(
|
||||
getGJSON(body, cfg.UserInfoUsernamePath),
|
||||
getGJSON(body, "username"),
|
||||
getGJSON(body, "preferred_username"),
|
||||
getGJSON(body, "name"),
|
||||
getGJSON(body, "user.username"),
|
||||
getGJSON(body, "user.name"),
|
||||
)
|
||||
subject = firstNonEmpty(
|
||||
getGJSON(body, cfg.UserInfoIDPath),
|
||||
getGJSON(body, "sub"),
|
||||
getGJSON(body, "id"),
|
||||
getGJSON(body, "user_id"),
|
||||
getGJSON(body, "uid"),
|
||||
getGJSON(body, "user.id"),
|
||||
)
|
||||
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return "", "", "", errors.New("userinfo missing id field")
|
||||
}
|
||||
if !isSafeLinuxDoSubject(subject) {
|
||||
return "", "", "", errors.New("userinfo returned invalid id field")
|
||||
}
|
||||
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
username = strings.TrimSpace(username)
|
||||
if username == "" {
|
||||
username = "linuxdo_" + subject
|
||||
}
|
||||
|
||||
return email, username, subject, nil
|
||||
}
|
||||
|
||||
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
|
||||
u, err := url.Parse(cfg.AuthorizeURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse authorize_url: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("response_type", "code")
|
||||
q.Set("client_id", cfg.ClientID)
|
||||
q.Set("redirect_uri", redirectURI)
|
||||
if strings.TrimSpace(cfg.Scopes) != "" {
|
||||
q.Set("scope", cfg.Scopes)
|
||||
}
|
||||
q.Set("state", state)
|
||||
if cfg.UsePKCE {
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) {
|
||||
fragment := url.Values{}
|
||||
fragment.Set("error", truncateFragmentValue(code))
|
||||
if strings.TrimSpace(message) != "" {
|
||||
fragment.Set("error_message", truncateFragmentValue(message))
|
||||
}
|
||||
if strings.TrimSpace(description) != "" {
|
||||
fragment.Set("error_description", truncateFragmentValue(description))
|
||||
}
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) {
|
||||
u, err := url.Parse(frontendCallback)
|
||||
if err != nil {
|
||||
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
|
||||
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
|
||||
return
|
||||
}
|
||||
if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
|
||||
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
|
||||
return
|
||||
}
|
||||
u.Fragment = fragment.Encode()
|
||||
c.Header("Cache-Control", "no-store")
|
||||
c.Header("Pragma", "no-cache")
|
||||
c.Redirect(http.StatusFound, u.String())
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseOAuthProviderError(body string) (providerErr string, providerDesc string) {
|
||||
body = strings.TrimSpace(body)
|
||||
if body == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
providerErr = firstNonEmpty(
|
||||
getGJSON(body, "error"),
|
||||
getGJSON(body, "code"),
|
||||
getGJSON(body, "error.code"),
|
||||
)
|
||||
providerDesc = firstNonEmpty(
|
||||
getGJSON(body, "error_description"),
|
||||
getGJSON(body, "error.message"),
|
||||
getGJSON(body, "message"),
|
||||
getGJSON(body, "detail"),
|
||||
)
|
||||
|
||||
if providerErr != "" || providerDesc != "" {
|
||||
return providerErr, providerDesc
|
||||
}
|
||||
|
||||
values, err := url.ParseQuery(body)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
providerErr = firstNonEmpty(values.Get("error"), values.Get("code"))
|
||||
providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message"))
|
||||
return providerErr, providerDesc
|
||||
}
|
||||
|
||||
func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) {
|
||||
body = strings.TrimSpace(body)
|
||||
if body == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
|
||||
if accessToken != "" {
|
||||
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
|
||||
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
|
||||
scope := strings.TrimSpace(getGJSON(body, "scope"))
|
||||
expiresIn := gjson.Get(body, "expires_in").Int()
|
||||
return &linuxDoTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
TokenType: tokenType,
|
||||
ExpiresIn: expiresIn,
|
||||
RefreshToken: refreshToken,
|
||||
Scope: scope,
|
||||
}, true
|
||||
}
|
||||
|
||||
values, err := url.ParseQuery(body)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
accessToken = strings.TrimSpace(values.Get("access_token"))
|
||||
if accessToken == "" {
|
||||
return nil, false
|
||||
}
|
||||
expiresIn := int64(0)
|
||||
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
|
||||
if v, err := strconv.ParseInt(raw, 10, 64); err == nil {
|
||||
expiresIn = v
|
||||
}
|
||||
}
|
||||
return &linuxDoTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
TokenType: strings.TrimSpace(values.Get("token_type")),
|
||||
ExpiresIn: expiresIn,
|
||||
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
|
||||
Scope: strings.TrimSpace(values.Get("scope")),
|
||||
}, true
|
||||
}
|
||||
|
||||
func getGJSON(body string, path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
res := gjson.Get(body, path)
|
||||
if !res.Exists() {
|
||||
return ""
|
||||
}
|
||||
return res.String()
|
||||
}
|
||||
|
||||
func truncateLogValue(value string, maxLen int) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" || maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(value) <= maxLen {
|
||||
return value
|
||||
}
|
||||
value = value[:maxLen]
|
||||
for !utf8.ValidString(value) {
|
||||
value = value[:len(value)-1]
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func singleLine(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(strings.Fields(value), " ")
|
||||
}
|
||||
|
||||
func sanitizeFrontendRedirectPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if len(path) > linuxDoOAuthMaxRedirectLen {
|
||||
return ""
|
||||
}
|
||||
// 只允许同源相对路径(避免开放重定向)。
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(path, "//") {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(path, "://") {
|
||||
return ""
|
||||
}
|
||||
if strings.ContainsAny(path, "\r\n") {
|
||||
return ""
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func isRequestHTTPS(c *gin.Context) bool {
|
||||
if c.Request.TLS != nil {
|
||||
return true
|
||||
}
|
||||
proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")))
|
||||
return proto == "https"
|
||||
}
|
||||
|
||||
func encodeCookieValue(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
|
||||
func decodeCookieValue(value string) (string, error) {
|
||||
raw, err := base64.RawURLEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(raw), nil
|
||||
}
|
||||
|
||||
func readCookieDecoded(c *gin.Context, name string) (string, error) {
|
||||
ck, err := c.Request.Cookie(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return decodeCookieValue(ck.Value)
|
||||
}
|
||||
|
||||
func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: linuxDoOAuthCookiePath,
|
||||
MaxAge: maxAgeSec,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func clearCookie(c *gin.Context, name string, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: linuxDoOAuthCookiePath,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func truncateFragmentValue(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
if len(value) > linuxDoOAuthMaxFragmentValueLen {
|
||||
value = value[:linuxDoOAuthMaxFragmentValueLen]
|
||||
for !utf8.ValidString(value) {
|
||||
value = value[:len(value)-1]
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func buildBearerAuthorization(tokenType, accessToken string) (string, error) {
|
||||
tokenType = strings.TrimSpace(tokenType)
|
||||
if tokenType == "" {
|
||||
tokenType = "Bearer"
|
||||
}
|
||||
if !strings.EqualFold(tokenType, "Bearer") {
|
||||
return "", fmt.Errorf("unsupported token_type: %s", tokenType)
|
||||
}
|
||||
|
||||
accessToken = strings.TrimSpace(accessToken)
|
||||
if accessToken == "" {
|
||||
return "", errors.New("missing access_token")
|
||||
}
|
||||
if strings.ContainsAny(accessToken, " \t\r\n") {
|
||||
return "", errors.New("access_token contains whitespace")
|
||||
}
|
||||
return "Bearer " + accessToken, nil
|
||||
}
|
||||
|
||||
func isSafeLinuxDoSubject(subject string) bool {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen {
|
||||
return false
|
||||
}
|
||||
for _, r := range subject {
|
||||
switch {
|
||||
case r >= '0' && r <= '9':
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r == '_' || r == '-':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func linuxDoSyntheticEmail(subject string) string {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return ""
|
||||
}
|
||||
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
|
||||
}
|
||||
108
backend/internal/handler/auth_linuxdo_oauth_test.go
Normal file
108
backend/internal/handler/auth_linuxdo_oauth_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSanitizeFrontendRedirectPath(t *testing.T) {
|
||||
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard"))
|
||||
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard "))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard"))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com"))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com"))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo"))
|
||||
|
||||
long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen)
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath(long))
|
||||
}
|
||||
|
||||
func TestBuildBearerAuthorization(t *testing.T) {
|
||||
auth, err := buildBearerAuthorization("", "token123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Bearer token123", auth)
|
||||
|
||||
auth, err = buildBearerAuthorization("bearer", "token123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Bearer token123", auth)
|
||||
|
||||
_, err = buildBearerAuthorization("MAC", "token123")
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = buildBearerAuthorization("Bearer", "token 123")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
|
||||
cfg := config.LinuxDoConnectConfig{
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
}
|
||||
|
||||
email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "123", subject)
|
||||
require.Equal(t, "alice", username)
|
||||
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
|
||||
}
|
||||
|
||||
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
|
||||
cfg := config.LinuxDoConnectConfig{
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
}
|
||||
|
||||
email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "123", subject)
|
||||
require.Equal(t, "linuxdo_123", username)
|
||||
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
|
||||
}
|
||||
|
||||
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
|
||||
cfg := config.LinuxDoConnectConfig{
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
}
|
||||
|
||||
_, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
|
||||
require.Error(t, err)
|
||||
|
||||
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
|
||||
_, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseOAuthProviderErrorJSON(t *testing.T) {
|
||||
code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`)
|
||||
require.Equal(t, "invalid_client", code)
|
||||
require.Equal(t, "bad secret", desc)
|
||||
}
|
||||
|
||||
func TestParseOAuthProviderErrorForm(t *testing.T) {
|
||||
code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier")
|
||||
require.Equal(t, "invalid_request", code)
|
||||
require.Equal(t, "Missing code_verifier", desc)
|
||||
}
|
||||
|
||||
func TestParseLinuxDoTokenResponseJSON(t *testing.T) {
|
||||
token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "t1", token.AccessToken)
|
||||
require.Equal(t, "Bearer", token.TokenType)
|
||||
require.Equal(t, int64(3600), token.ExpiresIn)
|
||||
require.Equal(t, "user", token.Scope)
|
||||
}
|
||||
|
||||
func TestParseLinuxDoTokenResponseForm(t *testing.T) {
|
||||
token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "t2", token.AccessToken)
|
||||
require.Equal(t, "bearer", token.TokenType)
|
||||
require.Equal(t, int64(60), token.ExpiresIn)
|
||||
}
|
||||
|
||||
func TestSingleLineStripsWhitespace(t *testing.T) {
|
||||
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
|
||||
require.Equal(t, "", singleLine("\n\t\r"))
|
||||
}
|
||||
@@ -53,16 +53,18 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
IPWhitelist: k.IPWhitelist,
|
||||
IPBlacklist: k.IPBlacklist,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,11 +252,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
|
||||
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
||||
// The account parameter allows caller to control what Account info is included.
|
||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog {
|
||||
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
|
||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageLog{
|
||||
result := &UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
@@ -290,21 +293,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
// IP 地址仅对管理员可见
|
||||
if includeIPAddress {
|
||||
result.IPAddress = l.IPAddress
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
||||
// It excludes Account details - users should not see account information.
|
||||
// It excludes Account details and IP address - users should not see these.
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
return usageLogFromServiceBase(l, nil)
|
||||
return usageLogFromServiceBase(l, nil, false)
|
||||
}
|
||||
|
||||
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
||||
// It includes minimal Account info (ID, Name only).
|
||||
// It includes minimal Account info (ID, Name only) and IP address.
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account))
|
||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
@@ -362,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
Errors: r.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
func PromoCodeFromService(pc *service.PromoCode) *PromoCode {
|
||||
if pc == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCode{
|
||||
ID: pc.ID,
|
||||
Code: pc.Code,
|
||||
BonusAmount: pc.BonusAmount,
|
||||
MaxUses: pc.MaxUses,
|
||||
UsedCount: pc.UsedCount,
|
||||
Status: pc.Status,
|
||||
ExpiresAt: pc.ExpiresAt,
|
||||
Notes: pc.Notes,
|
||||
CreatedAt: pc.CreatedAt,
|
||||
UpdatedAt: pc.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func PromoCodeUsageFromService(u *service.PromoCodeUsage) *PromoCodeUsage {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCodeUsage{
|
||||
ID: u.ID,
|
||||
PromoCodeID: u.PromoCodeID,
|
||||
UserID: u.UserID,
|
||||
BonusAmount: u.BonusAmount,
|
||||
UsedAt: u.UsedAt,
|
||||
User: UserFromServiceShallow(u.User),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ type SystemSettings struct {
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
|
||||
|
||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
@@ -50,5 +55,6 @@ type PublicSettings struct {
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -20,14 +20,16 @@ type User struct {
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"`
|
||||
IPBlacklist []string `json:"ip_blacklist"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
@@ -187,6 +189,9 @@ type UsageLog struct {
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
// IP 地址(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
@@ -245,3 +250,28 @@ type BulkAssignResult struct {
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
type PromoCode struct {
|
||||
ID int64 `json:"id"`
|
||||
Code string `json:"code"`
|
||||
BonusAmount float64 `json:"bonus_amount"`
|
||||
MaxUses int `json:"max_uses"`
|
||||
UsedCount int `json:"used_count"`
|
||||
Status string `json:"status"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// PromoCodeUsage 优惠码使用记录
|
||||
type PromoCodeUsage struct {
|
||||
ID int64 `json:"id"`
|
||||
PromoCodeID int64 `json:"promo_code_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
BonusAmount float64 `json:"bonus_amount"`
|
||||
UsedAt time.Time `json:"used_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -114,6 +115,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// 获取客户端 IP
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
@@ -273,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -283,10 +287,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: cip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -401,7 +406,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -411,10 +416,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: cip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -167,6 +168,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// 获取客户端 IP
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// For Gemini native API, do not send Claude-style ping frames.
|
||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
||||
|
||||
@@ -307,7 +311,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -317,10 +321,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: cip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ type AdminHandlers struct {
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
Setting *admin.SettingHandler
|
||||
System *admin.SystemHandler
|
||||
Subscription *admin.SubscriptionHandler
|
||||
|
||||
@@ -8,9 +8,12 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -93,6 +96,24 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
// 获取客户端 IP
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
if !openai.IsCodexCLIRequest(userAgent) {
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
if strings.TrimSpace(existingInstructions) == "" {
|
||||
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
||||
reqBody["instructions"] = instructions
|
||||
// Re-serialize body
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
@@ -231,7 +252,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
@@ -241,10 +262,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: cip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ func ProvideAdminHandlers(
|
||||
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
promoHandler *admin.PromoHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
systemHandler *admin.SystemHandler,
|
||||
subscriptionHandler *admin.SubscriptionHandler,
|
||||
@@ -36,6 +37,7 @@ func ProvideAdminHandlers(
|
||||
AntigravityOAuth: antigravityOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Promo: promoHandler,
|
||||
Setting: settingHandler,
|
||||
System: systemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
@@ -105,6 +107,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAntigravityOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewPromoHandler,
|
||||
admin.NewSettingHandler,
|
||||
ProvideSystemHandler,
|
||||
admin.NewSubscriptionHandler,
|
||||
|
||||
60
backend/internal/middleware/rate_limiter.go
Normal file
60
backend/internal/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RateLimiter Redis 速率限制器
|
||||
type RateLimiter struct {
|
||||
redis *redis.Client
|
||||
prefix string
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建速率限制器实例
|
||||
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
redis: redisClient,
|
||||
prefix: "rate_limit:",
|
||||
}
|
||||
}
|
||||
|
||||
// Limit 返回速率限制中间件
|
||||
// key: 限制类型标识
|
||||
// limit: 时间窗口内最大请求数
|
||||
// window: 时间窗口
|
||||
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
redisKey := r.prefix + key + ":" + ip
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 使用 INCR 原子操作增加计数
|
||||
count, err := r.redis.Incr(ctx, redisKey).Result()
|
||||
if err != nil {
|
||||
// Redis 错误时放行,避免影响正常服务
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 首次访问时设置过期时间
|
||||
if count == 1 {
|
||||
r.redis.Expire(ctx, redisKey, window)
|
||||
}
|
||||
|
||||
// 超过限制
|
||||
if count > int64(limit) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -9,4 +9,6 @@ const (
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
|
||||
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
||||
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
||||
Group Key = "ctx_group"
|
||||
)
|
||||
|
||||
168
backend/internal/pkg/ip/ip.go
Normal file
168
backend/internal/pkg/ip/ip.go
Normal file
@@ -0,0 +1,168 @@
|
||||
// Package ip 提供客户端 IP 地址提取工具。
|
||||
package ip
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetClientIP 从 Gin Context 中提取客户端真实 IP 地址。
|
||||
// 按以下优先级检查 Header:
|
||||
// 1. CF-Connecting-IP (Cloudflare)
|
||||
// 2. X-Real-IP (Nginx)
|
||||
// 3. X-Forwarded-For (取第一个非私有 IP)
|
||||
// 4. c.ClientIP() (Gin 内置方法)
|
||||
func GetClientIP(c *gin.Context) string {
|
||||
// 1. Cloudflare
|
||||
if ip := c.GetHeader("CF-Connecting-IP"); ip != "" {
|
||||
return normalizeIP(ip)
|
||||
}
|
||||
|
||||
// 2. Nginx X-Real-IP
|
||||
if ip := c.GetHeader("X-Real-IP"); ip != "" {
|
||||
return normalizeIP(ip)
|
||||
}
|
||||
|
||||
// 3. X-Forwarded-For (多个 IP 时取第一个公网 IP)
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
for _, ip := range ips {
|
||||
ip = strings.TrimSpace(ip)
|
||||
if ip != "" && !isPrivateIP(ip) {
|
||||
return normalizeIP(ip)
|
||||
}
|
||||
}
|
||||
// 如果都是私有 IP,返回第一个
|
||||
if len(ips) > 0 {
|
||||
return normalizeIP(strings.TrimSpace(ips[0]))
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Gin 内置方法
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||
func normalizeIP(ip string) string {
|
||||
ip = strings.TrimSpace(ip)
|
||||
// 移除端口号(如 "192.168.1.1:8080" -> "192.168.1.1")
|
||||
if host, _, err := net.SplitHostPort(ip); err == nil {
|
||||
return host
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 私有 IP 范围
|
||||
privateBlocks := []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
}
|
||||
|
||||
for _, block := range privateBlocks {
|
||||
_, cidr, err := net.ParseCIDR(block)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MatchesPattern 检查 IP 是否匹配指定的模式(支持单个 IP 或 CIDR)。
|
||||
// pattern 可以是:
|
||||
// - 单个 IP: "192.168.1.100"
|
||||
// - CIDR 范围: "192.168.1.0/24"
|
||||
func MatchesPattern(clientIP, pattern string) bool {
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 尝试解析为 CIDR
|
||||
if strings.Contains(pattern, "/") {
|
||||
_, cidr, err := net.ParseCIDR(pattern)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cidr.Contains(ip)
|
||||
}
|
||||
|
||||
// 作为单个 IP 处理
|
||||
patternIP := net.ParseIP(pattern)
|
||||
if patternIP == nil {
|
||||
return false
|
||||
}
|
||||
return ip.Equal(patternIP)
|
||||
}
|
||||
|
||||
// MatchesAnyPattern 检查 IP 是否匹配任意一个模式。
|
||||
func MatchesAnyPattern(clientIP string, patterns []string) bool {
|
||||
for _, pattern := range patterns {
|
||||
if MatchesPattern(clientIP, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CheckIPRestriction 检查 IP 是否被 API Key 的 IP 限制允许。
|
||||
// 返回值:(是否允许, 拒绝原因)
|
||||
// 逻辑:
|
||||
// 1. 先检查黑名单,如果在黑名单中则直接拒绝
|
||||
// 2. 如果白名单不为空,IP 必须在白名单中
|
||||
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
|
||||
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
|
||||
// 规范化 IP
|
||||
clientIP = normalizeIP(clientIP)
|
||||
if clientIP == "" {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 1. 检查黑名单
|
||||
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
|
||||
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// ValidateIPPattern 验证 IP 或 CIDR 格式是否有效。
|
||||
func ValidateIPPattern(pattern string) bool {
|
||||
if strings.Contains(pattern, "/") {
|
||||
_, _, err := net.ParseCIDR(pattern)
|
||||
return err == nil
|
||||
}
|
||||
return net.ParseIP(pattern) != nil
|
||||
}
|
||||
|
||||
// ValidateIPPatterns 验证多个 IP 或 CIDR 格式。
|
||||
// 返回无效的模式列表。
|
||||
func ValidateIPPatterns(patterns []string) []string {
|
||||
var invalid []string
|
||||
for _, p := range patterns {
|
||||
if !ValidateIPPattern(p) {
|
||||
invalid = append(invalid, p)
|
||||
}
|
||||
}
|
||||
return invalid
|
||||
}
|
||||
@@ -675,6 +675,40 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||
now := time.Now().UTC()
|
||||
payload := map[string]string{
|
||||
"rate_limited_at": now.Format(time.RFC3339),
|
||||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
path := "{antigravity_quota_scopes," + string(scope) + "}"
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
||||
path,
|
||||
raw,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -718,6 +752,27 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
builder := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -831,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
args = append(args, *updates.Status)
|
||||
idx++
|
||||
}
|
||||
if updates.Schedulable != nil {
|
||||
setClauses = append(setClauses, "schedulable = $"+itoa(idx))
|
||||
args = append(args, *updates.Schedulable)
|
||||
idx++
|
||||
}
|
||||
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
|
||||
if len(updates.Credentials) > 0 {
|
||||
payload, err := json.Marshal(updates.Credentials)
|
||||
|
||||
@@ -26,13 +26,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||
created, err := r.client.APIKey.Create().
|
||||
builder := r.client.APIKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
Save(ctx)
|
||||
SetNillableGroupID(key.GroupID)
|
||||
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
}
|
||||
if len(key.IPBlacklist) > 0 {
|
||||
builder.SetIPBlacklist(key.IPBlacklist)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
key.ID = created.ID
|
||||
key.CreatedAt = created.CreatedAt
|
||||
@@ -108,6 +116,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
builder.ClearGroupID()
|
||||
}
|
||||
|
||||
// IP 限制字段
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
} else {
|
||||
builder.ClearIPWhitelist()
|
||||
}
|
||||
if len(key.IPBlacklist) > 0 {
|
||||
builder.SetIPBlacklist(key.IPBlacklist)
|
||||
} else {
|
||||
builder.ClearIPBlacklist()
|
||||
}
|
||||
|
||||
affected, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -268,14 +288,16 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
return nil
|
||||
}
|
||||
out := &service.APIKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
@@ -317,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
|
||||
@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
out, err := r.GetByIDLite(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
// AccountCount is intentionally not loaded here; use GetByID when needed.
|
||||
m, err := r.client.Group.Query().
|
||||
Where(group.IDEQ(id)).
|
||||
Only(ctx)
|
||||
@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
|
||||
out := groupEntityToService(m)
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
return groupEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
||||
@@ -112,10 +120,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
return r.ListWithFilters(ctx, params, "", "", "", nil)
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
q := r.client.Group.Query()
|
||||
|
||||
if platform != "" {
|
||||
@@ -124,6 +132,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
if status != "" {
|
||||
q = q.Where(group.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(group.Or(
|
||||
group.NameContainsFold(search),
|
||||
group.DescriptionContainsFold(search),
|
||||
))
|
||||
}
|
||||
if isExclusive != nil {
|
||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
|
||||
repo *groupRepository
|
||||
}
|
||||
|
||||
type forbidSQLExecutor struct {
|
||||
called bool
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql exec")
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql query")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
|
||||
group := &service.Group{
|
||||
Name: "lite-group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
spy := &forbidSQLExecutor{}
|
||||
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
|
||||
|
||||
got, err := repo.GetByIDLite(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(group.ID, got.ID)
|
||||
s.Require().False(spy.called, "expected no direct sql executor usage")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := &service.Group{
|
||||
Name: "original",
|
||||
@@ -131,6 +167,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
pagination.PaginationParams{Page: 1, PageSize: 10},
|
||||
service.PlatformOpenAI,
|
||||
"",
|
||||
"",
|
||||
nil,
|
||||
)
|
||||
s.Require().NoError(err, "ListWithFilters base")
|
||||
@@ -152,7 +189,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, len(baseGroups)+1)
|
||||
// Verify all groups are OpenAI platform
|
||||
@@ -179,7 +216,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(service.StatusDisabled, groups[0].Status)
|
||||
@@ -204,12 +241,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
}))
|
||||
|
||||
isExclusive := true
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().True(groups[0].IsExclusive)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Search() {
|
||||
newRepo := func() (*groupRepository, context.Context) {
|
||||
tx := testEntTx(s.T())
|
||||
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
|
||||
}
|
||||
|
||||
containsID := func(groups []service.Group, id int64) bool {
|
||||
for i := range groups {
|
||||
if groups[i].ID == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
|
||||
s.Require().NoError(repo.Create(ctx, g))
|
||||
s.Require().NotZero(g.ID)
|
||||
return g
|
||||
}
|
||||
|
||||
newGroup := func(name string) *service.Group {
|
||||
return &service.Group{
|
||||
Name: name,
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
}
|
||||
|
||||
s.Run("search_name_should_match", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
|
||||
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_description_should_match", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := newGroup("it-group-search-desc-target")
|
||||
target.Description = "something about desc-needle in here"
|
||||
target = mustCreate(repo, ctx, target)
|
||||
|
||||
other := newGroup("it-group-search-desc-other")
|
||||
other.Description = "nothing to see here"
|
||||
other = mustCreate(repo, ctx, other)
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_nonexistent_should_return_empty", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
|
||||
|
||||
search := s.T().Name() + "__no_such_group__"
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(groups)
|
||||
})
|
||||
|
||||
s.Run("search_should_be_case_insensitive", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
|
||||
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_should_escape_like_wildcards", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
|
||||
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
|
||||
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
|
||||
|
||||
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
|
||||
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
|
||||
|
||||
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
|
||||
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
g1 := &service.Group{
|
||||
Name: "g1",
|
||||
@@ -244,7 +386,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
s.Require().NoError(err)
|
||||
|
||||
isExclusive := true
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(groups, 1)
|
||||
|
||||
273
backend/internal/repository/promo_code_repo.go
Normal file
273
backend/internal/repository/promo_code_repo.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type promoCodeRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
|
||||
return &promoCodeRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
code.UpdatedAt = created.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
m, err := client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
ForUpdate().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrPromoCodeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
code.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "")
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCode.Query()
|
||||
|
||||
if status != "" {
|
||||
q = q.Where(promocode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := promoCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
created, err := client.PromoCodeUsage.Create().
|
||||
SetPromoCodeID(usage.PromoCodeID).
|
||||
SetUserID(usage.UserID).
|
||||
SetBonusAmount(usage.BonusAmount).
|
||||
SetUsedAt(usage.UsedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usage.ID = created.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
|
||||
m, err := r.client.PromoCodeUsage.Query().
|
||||
Where(
|
||||
promocodeusage.PromoCodeIDEQ(promoCodeID),
|
||||
promocodeusage.UserIDEQ(userID),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeUsageEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
usages, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocodeusage.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outUsages := promoCodeUsageEntitiesToService(usages)
|
||||
|
||||
return outUsages, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.UpdateOneID(id).
|
||||
AddUsedCount(1).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Entity to Service conversions
|
||||
|
||||
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.PromoCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
BonusAmount: m.BonusAmount,
|
||||
MaxUses: m.MaxUses,
|
||||
UsedCount: m.UsedCount,
|
||||
Status: m.Status,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
|
||||
out := make([]service.PromoCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.PromoCodeUsage{
|
||||
ID: m.ID,
|
||||
PromoCodeID: m.PromoCodeID,
|
||||
UserID: m.UserID,
|
||||
BonusAmount: m.BonusAmount,
|
||||
UsedAt: m.UsedAt,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
|
||||
out := make([]service.PromoCodeUsage, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeUsageEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
@@ -110,6 +110,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
created_at
|
||||
@@ -119,7 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -130,6 +131,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
userAgent := nullString(log.UserAgent)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -163,6 +165,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration,
|
||||
firstToken,
|
||||
userAgent,
|
||||
ipAddress,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
createdAt,
|
||||
@@ -1873,6 +1876,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
durationMs sql.NullInt64
|
||||
firstTokenMs sql.NullInt64
|
||||
userAgent sql.NullString
|
||||
ipAddress sql.NullString
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
createdAt time.Time
|
||||
@@ -1905,6 +1909,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&durationMs,
|
||||
&firstTokenMs,
|
||||
&userAgent,
|
||||
&ipAddress,
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&createdAt,
|
||||
@@ -1959,6 +1964,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if userAgent.Valid {
|
||||
log.UserAgent = &userAgent.String
|
||||
}
|
||||
if ipAddress.Valid {
|
||||
log.IPAddress = &ipAddress.String
|
||||
}
|
||||
if imageSize.Valid {
|
||||
log.ImageSize = &imageSize.String
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewUsageLogRepository,
|
||||
NewSettingRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
|
||||
@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) {
|
||||
"name": "Key One",
|
||||
"group_id": null,
|
||||
"status": "active",
|
||||
"ip_whitelist": null,
|
||||
"ip_blacklist": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) {
|
||||
"name": "Key One",
|
||||
"group_id": null,
|
||||
"status": "active",
|
||||
"ip_whitelist": null,
|
||||
"ip_blacklist": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
@@ -304,6 +308,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"turnstile_enabled": true,
|
||||
"turnstile_site_key": "site-key",
|
||||
"turnstile_secret_key_configured": true,
|
||||
"linuxdo_connect_enabled": false,
|
||||
"linuxdo_connect_client_id": "",
|
||||
"linuxdo_connect_client_secret_configured": false,
|
||||
"linuxdo_connect_redirect_url": "",
|
||||
"site_name": "Sub2API",
|
||||
"site_logo": "",
|
||||
"site_subtitle": "Subtitle",
|
||||
@@ -390,7 +398,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
||||
@@ -567,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -583,7 +595,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ProviderSet 提供服务器层的依赖
|
||||
@@ -30,6 +31,7 @@ func ProvideRouter(
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -47,7 +49,7 @@ func ProvideRouter(
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -71,6 +74,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 IP 限制(白名单/黑名单)
|
||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||
if !allowed {
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 检查关联的用户
|
||||
if apiKey.User == nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
||||
@@ -91,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -149,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
@@ -173,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
|
||||
subscription, ok := value.(*service.UserSubscription)
|
||||
return subscription, ok
|
||||
}
|
||||
|
||||
func setGroupContext(c *gin.Context, group *service.Group) {
|
||||
if !service.IsGroupContextValid(group) {
|
||||
return
|
||||
}
|
||||
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) {
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
|
||||
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 99,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformGemini,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyService := service.NewAPIKeyService(
|
||||
fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{RunMode: config.RunModeSimple},
|
||||
)
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r := gin.New()
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
ID: 42,
|
||||
Name: "sub",
|
||||
Status: service.StatusActive,
|
||||
Hydrated: true,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
|
||||
invalidGroup := &service.Group{
|
||||
ID: group.ID,
|
||||
Platform: group.Platform,
|
||||
Status: group.Status,
|
||||
}
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup))
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SetupRouter 配置路由器中间件和路由
|
||||
@@ -21,6 +22,7 @@ func SetupRouter(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware2.Logger())
|
||||
@@ -33,7 +35,7 @@ func SetupRouter(
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -48,6 +50,7 @@ func registerRoutes(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
// 通用路由(健康检查、状态等)
|
||||
routes.RegisterCommonRoutes(r)
|
||||
@@ -56,7 +59,7 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
|
||||
@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
|
||||
// 卡密管理
|
||||
registerRedeemCodeRoutes(admin, h)
|
||||
|
||||
// 优惠码管理
|
||||
registerPromoCodeRoutes(admin, h)
|
||||
|
||||
// 系统设置
|
||||
registerSettingsRoutes(admin, h)
|
||||
|
||||
@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerPromoCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
promoCodes := admin.Group("/promo-codes")
|
||||
{
|
||||
promoCodes.GET("", h.Admin.Promo.List)
|
||||
promoCodes.GET("/:id", h.Admin.Promo.GetByID)
|
||||
promoCodes.POST("", h.Admin.Promo.Create)
|
||||
promoCodes.PUT("/:id", h.Admin.Promo.Update)
|
||||
promoCodes.DELETE("/:id", h.Admin.Promo.Delete)
|
||||
promoCodes.GET("/:id/usages", h.Admin.Promo.GetUsages)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
adminSettings := admin.Group("/settings")
|
||||
{
|
||||
|
||||
@@ -1,24 +1,36 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RegisterAuthRoutes 注册认证相关路由
|
||||
func RegisterAuthRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次
|
||||
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
|
||||
@@ -49,10 +49,12 @@ type AccountRepository interface {
|
||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
@@ -66,6 +68,7 @@ type AccountBulkUpdate struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
@@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
|
||||
panic("unexpected SetRateLimited call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
panic("unexpected SetAntigravityQuotaScopeLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
panic("unexpected SetOverloaded call")
|
||||
}
|
||||
@@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearRateLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearAntigravityQuotaScopes call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
panic("unexpected UpdateSessionWindow call")
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ type AdminService interface {
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
// Group management
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
|
||||
GetAllGroups(ctx context.Context) ([]Group, error)
|
||||
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||
GetGroup(ctx context.Context, id int64) (*Group, error)
|
||||
@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
// Group management implementations
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -575,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
||||
return fmt.Errorf("cannot set self as fallback group")
|
||||
}
|
||||
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
visited := map[int64]struct{}{}
|
||||
nextID := fallbackGroupID
|
||||
for {
|
||||
if _, seen := visited[nextID]; seen {
|
||||
return fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
visited[nextID] = struct{}{}
|
||||
if currentGroupID > 0 && nextID == currentGroupID {
|
||||
return fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
|
||||
if fallbackGroup.FallbackGroupID == nil {
|
||||
return nil
|
||||
}
|
||||
nextID = *fallbackGroup.FallbackGroupID
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
@@ -910,6 +926,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
if input.Schedulable != nil {
|
||||
repoUpdates.Schedulable = input.Schedulable
|
||||
}
|
||||
|
||||
// Run bulk update for column/jsonb fields first.
|
||||
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
|
||||
|
||||
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByIDLite call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
@@ -124,7 +128,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
|
||||
updated *Group // 记录 Update 调用的参数
|
||||
getByID *Group // GetByID 返回值
|
||||
getErr error // GetByID 返回的错误
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersIsExclusive *bool
|
||||
listWithFiltersGroups []Group
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
||||
@@ -35,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
@@ -47,8 +64,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
s.listWithFiltersIsExclusive = isExclusive
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersGroups)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersGroups, result, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
||||
@@ -195,3 +232,149 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
|
||||
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||||
// 测试:
|
||||
// 1. search 参数正常传递到 repository 层
|
||||
// 2. search 为空字符串时的行为
|
||||
// 3. search 与其他过滤条件组合使用
|
||||
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "alpha", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, groups)
|
||||
require.Equal(t, int64(0), total)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
|
||||
isExclusive := true
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), total)
|
||||
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "beta", repo.listWithFiltersSearch)
|
||||
require.NotNil(t, repo.listWithFiltersIsExclusive)
|
||||
require.True(t, *repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
fallbackID := int64(2)
|
||||
repo := &groupRepoStubForFallbackCycle{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
FallbackGroupID: &fallbackID,
|
||||
},
|
||||
fallbackID: {
|
||||
ID: fallbackID,
|
||||
FallbackGroupID: &groupID,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
type groupRepoStubForFallbackCycle struct {
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
238
backend/internal/service/admin_service_search_test.go
Normal file
238
backend/internal/service/admin_service_search_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type accountRepoStubForAdminList struct {
|
||||
accountRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersAccounts []Account
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersType = accountType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAccounts)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAccounts, result, nil
|
||||
}
|
||||
|
||||
type proxyRepoStubForAdminList struct {
|
||||
proxyRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersProtocol string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersProxies []Proxy
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
|
||||
listWithFiltersAndAccountCountCalls int
|
||||
listWithFiltersAndAccountCountParams pagination.PaginationParams
|
||||
listWithFiltersAndAccountCountProtocol string
|
||||
listWithFiltersAndAccountCountStatus string
|
||||
listWithFiltersAndAccountCountSearch string
|
||||
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
|
||||
listWithFiltersAndAccountCountResult *pagination.PaginationResult
|
||||
listWithFiltersAndAccountCountErr error
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersProtocol = protocol
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersProxies, result, nil
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersAndAccountCountCalls++
|
||||
s.listWithFiltersAndAccountCountParams = params
|
||||
s.listWithFiltersAndAccountCountProtocol = protocol
|
||||
s.listWithFiltersAndAccountCountStatus = status
|
||||
s.listWithFiltersAndAccountCountSearch = search
|
||||
|
||||
if s.listWithFiltersAndAccountCountErr != nil {
|
||||
return nil, nil, s.listWithFiltersAndAccountCountErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersAndAccountCountResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAndAccountCountProxies, result, nil
|
||||
}
|
||||
|
||||
type redeemRepoStubForAdminList struct {
|
||||
redeemRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersCodes []RedeemCode
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersType = codeType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersCodes)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersCodes, result, nil
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &accountRepoStubForAdminList{
|
||||
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), total)
|
||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "acc", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(7), total)
|
||||
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "http", repo.listWithFiltersProtocol)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "p1", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
|
||||
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(9), total)
|
||||
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
|
||||
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
|
||||
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
|
||||
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &redeemRepoStubForAdminList{
|
||||
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
|
||||
}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), total)
|
||||
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "ABC", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
@@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct {
|
||||
// 长前缀优先
|
||||
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
|
||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
|
||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||
@@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
originalModel := claudeReq.Model
|
||||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||||
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -603,7 +605,7 @@ urlFallbackLoop:
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||
}
|
||||
// 最后一次尝试也失败
|
||||
resp = &http.Response{
|
||||
@@ -696,7 +698,7 @@ urlFallbackLoop:
|
||||
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
@@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if len(body) == 0 {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||||
}
|
||||
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||||
|
||||
// 解析请求以获取 image_size(用于图片计费)
|
||||
imageSize := s.extractImageSize(body)
|
||||
@@ -1146,7 +1149,7 @@ urlFallbackLoop:
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
@@ -1200,7 +1203,7 @@ urlFallbackLoop:
|
||||
goto handleSuccess
|
||||
}
|
||||
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
@@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||||
if statusCode == 429 {
|
||||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||
@@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
|
||||
defaultDur = 5 * time.Minute
|
||||
}
|
||||
ra := time.Now().Add(defaultDur)
|
||||
log.Printf("%s status=429 rate_limited reset_in=%v (fallback)", prefix, defaultDur)
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||||
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
|
||||
if quotaScope == "" {
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
|
||||
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
log.Printf("%s status=429 rate_limited reset_at=%v reset_in=%v", prefix, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
|
||||
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
||||
if quotaScope == "" {
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
|
||||
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
// 其他错误码继续使用 rateLimitService
|
||||
|
||||
88
backend/internal/service/antigravity_quota_scope.go
Normal file
88
backend/internal/service/antigravity_quota_scope.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
|
||||
|
||||
// AntigravityQuotaScope 表示 Antigravity 的配额域
|
||||
type AntigravityQuotaScope string
|
||||
|
||||
const (
|
||||
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
|
||||
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
|
||||
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
|
||||
)
|
||||
|
||||
// resolveAntigravityQuotaScope 根据模型名称解析配额域
|
||||
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||
model := normalizeAntigravityModelName(requestedModel)
|
||||
if model == "" {
|
||||
return "", false
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(model, "claude-"):
|
||||
return AntigravityQuotaScopeClaude, true
|
||||
case strings.HasPrefix(model, "gemini-"):
|
||||
if isImageGenerationModel(model) {
|
||||
return AntigravityQuotaScopeGeminiImage, true
|
||||
}
|
||||
return AntigravityQuotaScopeGeminiText, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAntigravityModelName(model string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(model))
|
||||
normalized = strings.TrimPrefix(normalized, "models/")
|
||||
return normalized
|
||||
}
|
||||
|
||||
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
|
||||
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
if !a.IsSchedulable() {
|
||||
return false
|
||||
}
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return true
|
||||
}
|
||||
scope, ok := resolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||
if resetAt == nil {
|
||||
return true
|
||||
}
|
||||
now := time.Now()
|
||||
return !now.Before(*resetAt)
|
||||
}
|
||||
|
||||
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
|
||||
if a == nil || a.Extra == nil || scope == "" {
|
||||
return nil
|
||||
}
|
||||
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
rawScope, ok := rawScopes[string(scope)].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
|
||||
if !ok || strings.TrimSpace(resetAtRaw) == "" {
|
||||
return nil
|
||||
}
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &resetAt
|
||||
}
|
||||
@@ -3,16 +3,18 @@ package service
|
||||
import "time"
|
||||
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
Name string
|
||||
GroupID *int64
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
Name string
|
||||
GroupID *int64
|
||||
Status string
|
||||
IPWhitelist []string
|
||||
IPBlacklist []string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (k *APIKey) IsActive() bool {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
)
|
||||
@@ -20,6 +21,7 @@ var (
|
||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -57,16 +59,20 @@ type APIKeyCache interface {
|
||||
|
||||
// CreateAPIKeyRequest 创建API Key请求
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
|
||||
}
|
||||
|
||||
// APIKeyService API Key服务
|
||||
@@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证分组权限(如果指定了分组)
|
||||
if req.GroupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
@@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: StatusActive,
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: StatusActive,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||
@@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
return nil, ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
apiKey.Name = *req.Name
|
||||
@@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 IP 限制(空数组会清空设置)
|
||||
apiKey.IPWhitelist = req.IPWhitelist
|
||||
apiKey.IPBlacklist = req.IPBlacklist
|
||||
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -18,6 +22,7 @@ var (
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
@@ -47,6 +52,7 @@ type AuthService struct {
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
@@ -57,6 +63,7 @@ func NewAuthService(
|
||||
emailService *EmailService,
|
||||
turnstileService *TurnstileService,
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
@@ -65,21 +72,27 @@ func NewAuthService(
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "")
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
|
||||
if isReservedEmail(email) {
|
||||
return "", nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查是否需要邮件验证
|
||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||
@@ -132,10 +145,27 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
// 优先检查邮箱冲突错误(竞态条件下可能发生)
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return "", nil, ErrEmailExists
|
||||
}
|
||||
log.Printf("[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供)
|
||||
if promoCode != "" && s.promoService != nil {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
} else {
|
||||
// 重新获取用户信息以获取更新后的余额
|
||||
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
|
||||
user = updatedUser
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成token
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
@@ -152,11 +182,15 @@ type SendVerifyCodeResult struct {
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码(同步方式)
|
||||
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return ErrRegDisabled
|
||||
}
|
||||
|
||||
if isReservedEmail(email) {
|
||||
return ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
@@ -185,12 +219,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
||||
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
|
||||
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
log.Println("[Auth] Registration is disabled")
|
||||
return nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
if isReservedEmail(email) {
|
||||
return nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
@@ -270,7 +308,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return true
|
||||
return false // 安全默认:settingService 未配置时关闭注册
|
||||
}
|
||||
return s.settingService.IsRegistrationEnabled(ctx)
|
||||
}
|
||||
@@ -315,6 +353,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
|
||||
// - 如果邮箱已存在:直接登录(不需要本地密码)
|
||||
// - 如果邮箱不存在:创建新用户并登录
|
||||
//
|
||||
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
|
||||
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
|
||||
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" || len(email) > 255 {
|
||||
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
|
||||
username = strings.TrimSpace(username)
|
||||
if len([]rune(username)) > 100 {
|
||||
username = string([]rune(username)[:100])
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// OAuth 首次登录视为注册。
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// 新用户默认值。
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
Username: username,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
// 并发场景:GetByEmail 与 Create 之间用户被创建。
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error getting user after conflict: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error creating oauth user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error during oauth login: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
|
||||
if !user.IsActive() {
|
||||
return "", nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
|
||||
if user.Username == "" && username != "" {
|
||||
user.Username = username
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
@@ -357,6 +491,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
func randomHexString(byteLength int) (string, error) {
|
||||
if byteLength <= 0 {
|
||||
byteLength = 16
|
||||
}
|
||||
buf := make([]byte, byteLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func isReservedEmail(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
emailService,
|
||||
nil,
|
||||
nil,
|
||||
nil, // promoService
|
||||
)
|
||||
}
|
||||
|
||||
@@ -113,6 +114,15 @@ func TestAuthService_Register_Disabled(t *testing.T) {
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
|
||||
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
|
||||
@@ -122,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
@@ -134,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
|
||||
@@ -148,14 +158,16 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
|
||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||
require.ErrorContains(t, err, "verify code")
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailExists(t *testing.T) {
|
||||
repo := &userRepoStub{exists: true}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
@@ -163,23 +175,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
|
||||
|
||||
func TestAuthService_Register_CheckEmailError(t *testing.T) {
|
||||
repo := &userRepoStub{existsErr: errors.New("db down")}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
|
||||
require.ErrorIs(t, err, ErrEmailReserved)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateError(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
|
||||
// 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败
|
||||
repo := &userRepoStub{createErr: ErrEmailExists}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_Success(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 5}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
token, user, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -38,6 +38,12 @@ const (
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
const (
|
||||
PromoCodeStatusActive = "active"
|
||||
PromoCodeStatusDisabled = "disabled"
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
@@ -105,7 +111,17 @@ const (
|
||||
// Request identity patch (Claude -> Gemini systemInstruction injection)
|
||||
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
||||
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
)
|
||||
|
||||
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
|
||||
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
data.Attempts++
|
||||
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
||||
}
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
_ = s.cache.DeleteVerificationCode(ctx, email)
|
||||
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
|
||||
log.Printf("[Email] Failed to delete verification code after success: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -23,9 +24,11 @@ type mockAccountRepoForPlatform struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
|
||||
getByIDCalls int
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
m.getByIDCalls++
|
||||
if acc, ok := m.accountsByID[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
@@ -136,6 +139,9 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
|
||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -148,6 +154,9 @@ func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context,
|
||||
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -185,6 +194,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockGroupRepoForGateway struct {
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
getByIDLiteCalls int
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDLiteCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
@@ -891,6 +950,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
|
||||
return m.accountWaitCounts[accountID], nil
|
||||
}
|
||||
|
||||
type mockConcurrencyCache struct {
|
||||
acquireAccountCalls int
|
||||
loadBatchCalls int
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
m.acquireAccountCalls++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
m.loadBatchCalls++
|
||||
result := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
result[acc.ID] = &AccountLoadInfo{
|
||||
AccountID: acc.ID,
|
||||
CurrentConcurrency: 0,
|
||||
WaitingCount: 0,
|
||||
LoadRate: 0,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
|
||||
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@@ -989,6 +1116,78 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
|
||||
})
|
||||
|
||||
t.Run("粘性命中-不调用GetByID", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"sticky": 1},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
concurrencyCache := &mockConcurrencyCache{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(1), result.Account.ID)
|
||||
require.Equal(t, 0, repo.getByIDCalls, "粘性命中不应调用GetByID")
|
||||
require.Equal(t, 0, concurrencyCache.loadBatchCalls, "粘性命中应在负载批量查询前返回")
|
||||
})
|
||||
|
||||
t.Run("粘性账号不在候选集-回退负载感知选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"sticky": 1},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
concurrencyCache := &mockConcurrencyCache{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "粘性账号不在候选集时应回退到可用账号")
|
||||
require.Equal(t, 0, repo.getByIDCalls, "粘性账号缺失不应回退到GetByID")
|
||||
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
|
||||
})
|
||||
|
||||
t.Run("无可用账号-返回错误", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{},
|
||||
@@ -1013,3 +1212,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(42)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(42)
|
||||
ctxGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
|
||||
groupID := int64(42)
|
||||
invalidGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
hydratedGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
|
||||
svc := &GatewayService{}
|
||||
ctx = svc.withGroupContext(ctx, hydratedGroup)
|
||||
|
||||
got, ok := ctx.Value(ctxkey.Group).(*Group)
|
||||
require.True(t, ok)
|
||||
require.Same(t, hydratedGroup, got)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
Hydrated: true,
|
||||
}
|
||||
fallbackGroup := &Group{
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{fallbackID: fallbackGroup},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
}
|
||||
fallbackGroup := &Group{
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &groupID,
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{
|
||||
groupID: group,
|
||||
fallbackID: fallbackGroup,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
|
||||
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gotGroup)
|
||||
require.Nil(t, gotID)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
defaultMaxLineSize = 10 * 1024 * 1024
|
||||
defaultMaxLineSize = 40 * 1024 * 1024
|
||||
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
||||
)
|
||||
@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
groupID = resolvedGroupID
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
platform = group.Platform
|
||||
|
||||
// 检查 Claude Code 客户端限制
|
||||
if group.ClaudeCodeOnly {
|
||||
isClaudeCode := IsClaudeCodeClient(ctx)
|
||||
if !isClaudeCode {
|
||||
// 非 Claude Code 客户端,检查是否有降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
// 使用降级分组重新调度
|
||||
fallbackGroupID := *group.FallbackGroupID
|
||||
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
|
||||
}
|
||||
// 无降级分组,拒绝访问
|
||||
return nil, ErrClaudeCodeOnly
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 无分组时只使用原生 anthropic 平台
|
||||
platform = PlatformAnthropic
|
||||
@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
|
||||
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||
group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}, nil
|
||||
}
|
||||
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -478,10 +465,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) &&
|
||||
// 粘性命中仅在当前可调度候选集中生效。
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
account, ok := accountByID[accountID]
|
||||
if ok && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulable() &&
|
||||
account.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
@@ -519,6 +511,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -652,51 +647,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context {
|
||||
if !IsGroupContextValid(group) {
|
||||
return ctx
|
||||
}
|
||||
if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, ctxkey.Group, group)
|
||||
}
|
||||
|
||||
func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group {
|
||||
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID {
|
||||
return group
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
|
||||
if group := s.groupFromContext(ctx, groupID); group != nil {
|
||||
return group, nil
|
||||
}
|
||||
group, err := s.groupRepo.GetByIDLite(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
||||
if groupID == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
currentID := *groupID
|
||||
visited := map[int64]struct{}{}
|
||||
for {
|
||||
if _, seen := visited[currentID]; seen {
|
||||
return nil, nil, fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
visited[currentID] = struct{}{}
|
||||
|
||||
group, err := s.resolveGroupByID(ctx, currentID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) {
|
||||
return group, ¤tID, nil
|
||||
}
|
||||
|
||||
if group.FallbackGroupID == nil {
|
||||
return nil, nil, ErrClaudeCodeOnly
|
||||
}
|
||||
currentID = *group.FallbackGroupID
|
||||
}
|
||||
}
|
||||
|
||||
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
|
||||
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
|
||||
// - 有降级分组:返回降级分组的 ID
|
||||
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
|
||||
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
|
||||
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
||||
if groupID == nil {
|
||||
return groupID, nil
|
||||
return nil, groupID, nil
|
||||
}
|
||||
|
||||
// 强制平台模式不检查 Claude Code 限制
|
||||
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
||||
return groupID, nil
|
||||
return nil, groupID, nil
|
||||
}
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !group.ClaudeCodeOnly {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 分组启用了 Claude Code 限制
|
||||
if IsClaudeCodeClient(ctx) {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 非 Claude Code 客户端,检查降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
return group.FallbackGroupID, nil
|
||||
}
|
||||
|
||||
return nil, ErrClaudeCodeOnly
|
||||
return group, resolvedID, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) {
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
return forcePlatform, true, nil
|
||||
}
|
||||
if group != nil {
|
||||
return group.Platform, false, nil
|
||||
}
|
||||
if groupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
group, err := s.resolveGroupByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("get group failed: %w", err)
|
||||
return "", false, err
|
||||
}
|
||||
return group.Platform, false, nil
|
||||
}
|
||||
@@ -812,7 +853,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
@@ -844,6 +885,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -901,7 +945,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
@@ -936,6 +980,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -2247,6 +2294,7 @@ type RecordUsageInput struct {
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -2337,6 +2385,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
// 添加 IPAddress
|
||||
if input.IPAddress != "" {
|
||||
usageLog.IPAddress = &input.IPAddress
|
||||
}
|
||||
|
||||
// 添加分组和订阅关联
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
|
||||
@@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
var err error
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
}
|
||||
platform = group.Platform
|
||||
} else {
|
||||
@@ -114,7 +120,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
valid := false
|
||||
if account.Platform == platform {
|
||||
valid = true
|
||||
@@ -172,6 +178,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -121,6 +122,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -131,6 +135,9 @@ func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, i
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -146,10 +153,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
|
||||
type mockGroupRepoForGemini struct {
|
||||
groups map[int64]*Group
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
getByIDLiteCalls int
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDLiteCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
@@ -166,7 +184,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([
|
||||
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
|
||||
@@ -242,6 +260,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
|
||||
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformGemini,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {ID: groupID, Platform: PlatformGemini},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -10,6 +10,7 @@ type Group struct {
|
||||
RateMultiplier float64
|
||||
IsExclusive bool
|
||||
Status string
|
||||
Hydrated bool // indicates the group was loaded from a trusted repository source
|
||||
|
||||
SubscriptionType string
|
||||
DailyLimitUSD *float64
|
||||
@@ -72,3 +73,20 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
|
||||
return g.ImagePrice2K
|
||||
}
|
||||
}
|
||||
|
||||
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
||||
func IsGroupContextValid(group *Group) bool {
|
||||
if group == nil {
|
||||
return false
|
||||
}
|
||||
if group.ID <= 0 {
|
||||
return false
|
||||
}
|
||||
if !group.Hydrated {
|
||||
return false
|
||||
}
|
||||
if group.Platform == "" || group.Status == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -16,12 +16,13 @@ var (
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *Group) error
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
GetByIDLite(ctx context.Context, id int64) (*Group, error)
|
||||
Update(ctx context.Context, group *Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||
|
||||
|
||||
@@ -455,6 +455,10 @@ func getOpenCodeCodexHeader() string {
|
||||
return getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
||||
}
|
||||
|
||||
func GetOpenCodeInstructions() string {
|
||||
return getOpenCodeCodexHeader()
|
||||
}
|
||||
|
||||
func filterCodexInput(input []any) []any {
|
||||
filtered := make([]any, 0, len(input))
|
||||
for _, item := range input {
|
||||
|
||||
@@ -1251,6 +1251,7 @@ type OpenAIRecordUsageInput struct {
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -1325,6 +1326,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
// 添加 IPAddress
|
||||
if input.IPAddress != "" {
|
||||
usageLog.IPAddress = &input.IPAddress
|
||||
}
|
||||
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
|
||||
73
backend/internal/service/promo_code.go
Normal file
73
backend/internal/service/promo_code.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
type PromoCode struct {
|
||||
ID int64
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
UsedCount int
|
||||
Status string
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// 关联
|
||||
UsageRecords []PromoCodeUsage
|
||||
}
|
||||
|
||||
// PromoCodeUsage 优惠码使用记录
|
||||
type PromoCodeUsage struct {
|
||||
ID int64
|
||||
PromoCodeID int64
|
||||
UserID int64
|
||||
BonusAmount float64
|
||||
UsedAt time.Time
|
||||
|
||||
// 关联
|
||||
PromoCode *PromoCode
|
||||
User *User
|
||||
}
|
||||
|
||||
// CanUse 检查优惠码是否可用
|
||||
func (p *PromoCode) CanUse() bool {
|
||||
if p.Status != PromoCodeStatusActive {
|
||||
return false
|
||||
}
|
||||
if p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if p.MaxUses > 0 && p.UsedCount >= p.MaxUses {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsExpired 检查是否已过期
|
||||
func (p *PromoCode) IsExpired() bool {
|
||||
return p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt)
|
||||
}
|
||||
|
||||
// CreatePromoCodeInput 创建优惠码输入
|
||||
type CreatePromoCodeInput struct {
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
}
|
||||
|
||||
// UpdatePromoCodeInput 更新优惠码输入
|
||||
type UpdatePromoCodeInput struct {
|
||||
Code *string
|
||||
BonusAmount *float64
|
||||
MaxUses *int
|
||||
Status *string
|
||||
ExpiresAt *time.Time
|
||||
Notes *string
|
||||
}
|
||||
30
backend/internal/service/promo_code_repository.go
Normal file
30
backend/internal/service/promo_code_repository.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// PromoCodeRepository 优惠码仓储接口
|
||||
type PromoCodeRepository interface {
|
||||
// 基础 CRUD
|
||||
Create(ctx context.Context, code *PromoCode) error
|
||||
GetByID(ctx context.Context, id int64) (*PromoCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*PromoCode, error)
|
||||
GetByCodeForUpdate(ctx context.Context, code string) (*PromoCode, error) // 带行锁的查询,用于并发控制
|
||||
Update(ctx context.Context, code *PromoCode) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// 列表查询
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
|
||||
// 使用记录
|
||||
CreateUsage(ctx context.Context, usage *PromoCodeUsage) error
|
||||
GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*PromoCodeUsage, error)
|
||||
ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error)
|
||||
|
||||
// 计数操作
|
||||
IncrementUsedCount(ctx context.Context, id int64) error
|
||||
}
|
||||
256
backend/internal/service/promo_service.go
Normal file
256
backend/internal/service/promo_service.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPromoCodeNotFound = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found")
|
||||
ErrPromoCodeExpired = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired")
|
||||
ErrPromoCodeDisabled = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled")
|
||||
ErrPromoCodeMaxUsed = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses")
|
||||
ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code")
|
||||
ErrPromoCodeInvalid = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code")
|
||||
)
|
||||
|
||||
// PromoService 优惠码服务
|
||||
type PromoService struct {
|
||||
promoRepo PromoCodeRepository
|
||||
userRepo UserRepository
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
// NewPromoService 创建优惠码服务实例
|
||||
func NewPromoService(
|
||||
promoRepo PromoCodeRepository,
|
||||
userRepo UserRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
entClient *dbent.Client,
|
||||
) *PromoService {
|
||||
return &PromoService{
|
||||
promoRepo: promoRepo,
|
||||
userRepo: userRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePromoCode 验证优惠码(注册前调用)
|
||||
// 返回 nil, nil 表示空码(不报错)
|
||||
func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil, nil // 空码不报错,直接返回
|
||||
}
|
||||
|
||||
promoCode, err := s.promoRepo.GetByCode(ctx, code)
|
||||
if err != nil {
|
||||
// 保留原始错误类型,不要统一映射为 NotFound
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// validatePromoCodeStatus 验证优惠码状态
|
||||
func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error {
|
||||
if !promoCode.CanUse() {
|
||||
if promoCode.IsExpired() {
|
||||
return ErrPromoCodeExpired
|
||||
}
|
||||
if promoCode.Status == PromoCodeStatusDisabled {
|
||||
return ErrPromoCodeDisabled
|
||||
}
|
||||
if promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses {
|
||||
return ErrPromoCodeMaxUsed
|
||||
}
|
||||
return ErrPromoCodeInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPromoCode 应用优惠码(注册成功后调用)
|
||||
// 使用事务和行锁确保并发安全
|
||||
func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 开启事务
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
// 在事务中获取并锁定优惠码记录(FOR UPDATE)
|
||||
promoCode, err := s.promoRepo.GetByCodeForUpdate(txCtx, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中验证优惠码状态
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中检查用户是否已使用过此优惠码
|
||||
existing, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promoCode.ID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing usage: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return ErrPromoCodeAlreadyUsed
|
||||
}
|
||||
|
||||
// 增加用户余额
|
||||
if err := s.userRepo.UpdateBalance(txCtx, userID, promoCode.BonusAmount); err != nil {
|
||||
return fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
|
||||
// 创建使用记录
|
||||
usage := &PromoCodeUsage{
|
||||
PromoCodeID: promoCode.ID,
|
||||
UserID: userID,
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
UsedAt: time.Now(),
|
||||
}
|
||||
if err := s.promoRepo.CreateUsage(txCtx, usage); err != nil {
|
||||
return fmt.Errorf("create usage record: %w", err)
|
||||
}
|
||||
|
||||
// 增加使用次数
|
||||
if err := s.promoRepo.IncrementUsedCount(txCtx, promoCode.ID); err != nil {
|
||||
return fmt.Errorf("increment used count: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateRandomCode 生成随机优惠码
|
||||
func (s *PromoService) GenerateRandomCode() (string, error) {
|
||||
bytes := make([]byte, 8)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
return strings.ToUpper(hex.EncodeToString(bytes)), nil
|
||||
}
|
||||
|
||||
// Create 创建优惠码
|
||||
func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) {
|
||||
code := strings.TrimSpace(input.Code)
|
||||
if code == "" {
|
||||
// 自动生成
|
||||
var err error
|
||||
code, err = s.GenerateRandomCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
promoCode := &PromoCode{
|
||||
Code: strings.ToUpper(code),
|
||||
BonusAmount: input.BonusAmount,
|
||||
MaxUses: input.MaxUses,
|
||||
UsedCount: 0,
|
||||
Status: PromoCodeStatusActive,
|
||||
ExpiresAt: input.ExpiresAt,
|
||||
Notes: input.Notes,
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Create(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("create promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取优惠码
|
||||
func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) {
|
||||
code, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// Update 更新优惠码
|
||||
func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) {
|
||||
promoCode, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Code != nil {
|
||||
promoCode.Code = strings.ToUpper(strings.TrimSpace(*input.Code))
|
||||
}
|
||||
if input.BonusAmount != nil {
|
||||
promoCode.BonusAmount = *input.BonusAmount
|
||||
}
|
||||
if input.MaxUses != nil {
|
||||
promoCode.MaxUses = *input.MaxUses
|
||||
}
|
||||
if input.Status != nil {
|
||||
promoCode.Status = *input.Status
|
||||
}
|
||||
if input.ExpiresAt != nil {
|
||||
promoCode.ExpiresAt = input.ExpiresAt
|
||||
}
|
||||
if input.Notes != nil {
|
||||
promoCode.Notes = *input.Notes
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Update(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("update promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// Delete 删除优惠码
|
||||
func (s *PromoService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.promoRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete promo code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取优惠码列表
|
||||
func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListWithFilters(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// ListUsages 获取使用记录
|
||||
func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params)
|
||||
}
|
||||
@@ -345,7 +345,7 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
||||
|
||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||
if status == "allowed" && account.IsRateLimited() {
|
||||
if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -353,7 +353,10 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
||||
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
|
||||
return s.accountRepo.ClearRateLimit(ctx, accountID)
|
||||
if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@@ -64,6 +65,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyAPIBaseURL,
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocURL,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -71,6 +73,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
return nil, fmt.Errorf("get public settings: %w", err)
|
||||
}
|
||||
|
||||
linuxDoEnabled := false
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
linuxDoEnabled = raw == "true"
|
||||
} else {
|
||||
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
|
||||
}
|
||||
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
@@ -82,6 +91,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -111,6 +121,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
||||
}
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled)
|
||||
updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID
|
||||
updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL
|
||||
if settings.LinuxDoConnectClientSecret != "" {
|
||||
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
|
||||
}
|
||||
|
||||
// OEM设置
|
||||
updates[SettingKeySiteName] = settings.SiteName
|
||||
updates[SettingKeySiteLogo] = settings.SiteLogo
|
||||
@@ -141,8 +159,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||
if err != nil {
|
||||
// 默认开放注册
|
||||
return true
|
||||
// 安全默认:如果设置不存在或查询出错,默认关闭注册
|
||||
return false
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
@@ -271,6 +289,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
// LinuxDo Connect 设置:
|
||||
// - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭)
|
||||
// - 支持在后台“系统设置”中覆盖并持久化(存储于 DB)
|
||||
linuxDoBase := config.LinuxDoConnectConfig{}
|
||||
if s.cfg != nil {
|
||||
linuxDoBase = s.cfg.LinuxDo
|
||||
}
|
||||
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
result.LinuxDoConnectEnabled = raw == "true"
|
||||
} else {
|
||||
result.LinuxDoConnectEnabled = linuxDoBase.Enabled
|
||||
}
|
||||
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
|
||||
result.LinuxDoConnectClientID = strings.TrimSpace(v)
|
||||
} else {
|
||||
result.LinuxDoConnectClientID = linuxDoBase.ClientID
|
||||
}
|
||||
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||
result.LinuxDoConnectRedirectURL = strings.TrimSpace(v)
|
||||
} else {
|
||||
result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL
|
||||
}
|
||||
|
||||
result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret])
|
||||
if result.LinuxDoConnectClientSecret == "" {
|
||||
result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret)
|
||||
}
|
||||
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
|
||||
|
||||
// Model fallback settings
|
||||
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
|
||||
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
|
||||
@@ -289,6 +339,99 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
return result
|
||||
}
|
||||
|
||||
// GetLinuxDoConnectOAuthConfig 返回用于登录的“最终生效” LinuxDo Connect 配置。
|
||||
//
|
||||
// 优先级:
|
||||
// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值
|
||||
// - 否则回退到 config.yaml/env 的值
|
||||
func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if s == nil || s.cfg == nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||
}
|
||||
|
||||
effective := s.cfg.LinuxDo
|
||||
|
||||
keys := []string{
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
SettingKeyLinuxDoConnectClientID,
|
||||
SettingKeyLinuxDoConnectClientSecret,
|
||||
SettingKeyLinuxDoConnectRedirectURL,
|
||||
}
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return config.LinuxDoConnectConfig{}, fmt.Errorf("get linuxdo connect settings: %w", err)
|
||||
}
|
||||
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
effective.Enabled = raw == "true"
|
||||
}
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.ClientID = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientSecret]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.ClientSecret = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.RedirectURL = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
if !effective.Enabled {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||
}
|
||||
|
||||
// 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。
|
||||
if strings.TrimSpace(effective.ClientID) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.AuthorizeURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.TokenURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.UserInfoURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.RedirectURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.FrontendRedirectURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured")
|
||||
}
|
||||
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid")
|
||||
}
|
||||
if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
|
||||
}
|
||||
|
||||
method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod))
|
||||
switch method {
|
||||
case "", "client_secret_post", "client_secret_basic":
|
||||
if strings.TrimSpace(effective.ClientSecret) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
|
||||
}
|
||||
case "none":
|
||||
if !effective.UsePKCE {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
|
||||
}
|
||||
default:
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
|
||||
}
|
||||
|
||||
return effective, nil
|
||||
}
|
||||
|
||||
// getStringOrDefault 获取字符串值或默认值
|
||||
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
|
||||
if value, ok := settings[key]; ok && value != "" {
|
||||
|
||||
@@ -18,6 +18,13 @@ type SystemSettings struct {
|
||||
TurnstileSecretKey string
|
||||
TurnstileSecretKeyConfigured bool
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
LinuxDoConnectEnabled bool
|
||||
LinuxDoConnectClientID string
|
||||
LinuxDoConnectClientSecret string
|
||||
LinuxDoConnectClientSecretConfigured bool
|
||||
LinuxDoConnectRedirectURL string
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
@@ -51,5 +58,6 @@ type PublicSettings struct {
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
LinuxDoOAuthEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ type UsageLog struct {
|
||||
DurationMs *int
|
||||
FirstTokenMs *int
|
||||
UserAgent *string
|
||||
IPAddress *string
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int
|
||||
|
||||
@@ -87,6 +87,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountService,
|
||||
NewProxyService,
|
||||
NewRedeemService,
|
||||
NewPromoService,
|
||||
NewUsageService,
|
||||
NewDashboardService,
|
||||
ProvidePricingService,
|
||||
|
||||
Reference in New Issue
Block a user