Compare commits

...

20 Commits

Author SHA1 Message Date
yangjianbo
d75cd820b0 fix(认证): 订阅兑换失效认证缓存
订阅兑换后同步失效认证缓存避免授权快照滞后
补充单测覆盖订阅兑换的失效场景

测试: go test ./... -tags=unit
2026-01-11 10:55:26 +08:00
yangjianbo
cb3e08dda4 fix(认证): 补齐余额与删除场景缓存失效
为 Usage/Promo/Redeem 注入认证缓存失效逻辑
删除用户与分组前先失效认证缓存降低窗口
补充回归测试验证失效调用

测试: make test
2026-01-11 10:55:25 +08:00
yangjianbo
44a93c1922 perf(认证): 引入 API Key 认证缓存与轻量删除查询
增加 L1/L2 缓存、负缓存与单飞回源
使用 key+owner 轻量查询替代全量加载并清理旧接口
补充缓存失效与余额更新测试,修复随机抖动 lint

测试: make test
2026-01-11 10:55:25 +08:00
Wesley Liddick
9cba595fd0 Merge pull request #233 from cyhhao/main
fix(openai): 对齐 OpenCode OAuth instructions,保持 Codex CLI 透明转发
2026-01-11 10:41:57 +08:00
shaw
56fc2764e4 chore: remove accidentally committed test binary 2026-01-11 10:37:09 +08:00
Wesley Liddick
0c4f1762c9 Merge pull request #232 from Edric-Li/feat/api-key-ip-restriction
feat(settings): 首页自定义内容 & 配置注入优化
2026-01-11 10:36:01 +08:00
cyhhao
80c1cdf024 fix(lint): trim unused codex helpers 2026-01-10 22:45:29 +08:00
Edric Li
0fa5a6015e feat(settings): add iframe CSP warning for home content
Add a warning message to inform admins that some websites may have
X-Frame-Options or CSP policies that prevent iframe embedding.
2026-01-10 22:35:33 +08:00
cyhhao
1a641392d9 Merge up/main 2026-01-10 21:57:57 +08:00
cyhhao
36b817d008 Align OAuth transform with OpenCode instructions 2026-01-10 20:53:16 +08:00
kzw200015
24d19a5f78 fix: 从codex请求参数中移除max_output_tokens (#231)
某些客户端比如 opencode 会在请求中附加 max_output_tokens,这会导致上游返回400错误
2026-01-10 19:37:04 +08:00
Edric Li
3fb4a2b0ff style: replace interface{} with any per golangci-lint rules 2026-01-10 19:08:41 +08:00
Edric Li
0772cdda0f fix: update API contract test for home_content field and fix gofmt 2026-01-10 19:01:00 +08:00
Edric Li
f6f072cb9a Merge branch 'main' into feat/api-key-ip-restriction 2026-01-10 18:49:50 +08:00
Edric Li
5265b12cc7 feat(settings): add home content customization and config injection
- Add home_content setting for custom homepage (HTML or iframe URL)
- Inject public settings into index.html to eliminate page flash
- Support ETag caching with automatic invalidation on settings update
- Add Vite plugin for dev mode settings injection
- Refactor HomeView to use appStore instead of local API calls
2026-01-10 18:37:44 +08:00
cyhhao
eb06006d6c Make Codex CLI passthrough 2026-01-10 03:12:56 +08:00
Edric Li
e83f644c3f fix: update API contract tests for ip_whitelist/ip_blacklist fields
Add ip_whitelist and ip_blacklist fields to expected JSON responses
in API contract tests to match the new API key schema.
2026-01-09 21:37:07 +08:00
Edric Li
6b97a8be28 Merge branch 'main' into feat/api-key-ip-restriction 2026-01-09 21:34:28 +08:00
Edric Li
90798f14b5 feat(api-key): add IP whitelist/blacklist restriction and usage log IP tracking
- Add IP restriction feature for API keys (whitelist/blacklist with CIDR support)
- Add IP address logging to usage logs (admin-only visibility)
- Remove billing_type column from usage logs UI (redundant)
- Use generic "Access denied" error message for security

Backend:
- New ip package with IP/CIDR validation and matching utilities
- Database migrations for ip_whitelist, ip_blacklist (api_keys) and ip_address (usage_logs)
- Middleware IP restriction check after API key validation
- Input validation for IP/CIDR patterns on create/update

Frontend:
- API key form with enable toggle for IP restriction
- Shield icon indicator in table for keys with IP restriction
- Removed billing_type filter and column from usage views
2026-01-09 21:24:59 +08:00
cyhhao
7a06c4873e Fix Codex OAuth tool mapping 2026-01-09 18:35:58 +08:00
54 changed files with 2912 additions and 303 deletions

View File

@@ -55,23 +55,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingCache := repository.NewBillingCache(redisClient)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardService := service.NewDashboardService(usageLogRepository)
@@ -79,7 +80,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
@@ -148,7 +149,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, redisClient)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, settingService, redisClient)
httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)

View File

@@ -44,11 +44,13 @@ require (
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgraph-io/ristretto v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect

View File

@@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=

View File

@@ -36,25 +36,26 @@ const (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
}
// UpdateConfig 在线更新相关配置
@@ -274,6 +275,13 @@ type DatabaseConfig struct {
}
func (d *DatabaseConfig) DSN() string {
// 当密码为空时不包含 password 参数,避免 libpq 解析错误
if d.Password == "" {
return fmt.Sprintf(
"host=%s port=%d user=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.DBName, d.SSLMode,
)
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
@@ -285,6 +293,13 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
if tz == "" {
tz = "Asia/Shanghai"
}
// 当密码为空时不包含 password 参数,避免 libpq 解析错误
if d.Password == "" {
return fmt.Sprintf(
"host=%s port=%d user=%s dbname=%s sslmode=%s TimeZone=%s",
d.Host, d.Port, d.User, d.DBName, d.SSLMode, tz,
)
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz,
@@ -361,6 +376,16 @@ type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
}
// APIKeyAuthCacheConfig API Key 认证缓存配置
type APIKeyAuthCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
L2TTLSeconds int `mapstructure:"l2_ttl_seconds"`
NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
Singleflight bool `mapstructure:"singleflight"`
}
func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
@@ -655,6 +680,14 @@ func setDefaults() {
// Timezone (default to Asia/Shanghai for Chinese users)
viper.SetDefault("timezone", "Asia/Shanghai")
// API Key auth cache
viper.SetDefault("api_key_auth_cache.l1_size", 65535)
viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15)
viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300)
viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30)
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true)
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)

View File

@@ -62,6 +62,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
@@ -107,6 +108,7 @@ type UpdateSettingsRequest struct {
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -229,6 +231,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
@@ -277,6 +280,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -377,6 +381,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DocURL != after.DocURL {
changed = append(changed, "doc_url")
}
if before.HomeContent != after.HomeContent {
changed = append(changed, "home_content")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}

View File

@@ -28,6 +28,7 @@ type SystemSettings struct {
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@@ -55,6 +56,7 @@ type PublicSettings struct {
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -93,19 +94,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent")
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
if !openai.IsCodexCLIRequest(userAgent) {
reqBody["instructions"] = openai.DefaultInstructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
existingInstructions, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existingInstructions) == "" {
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
reqBody["instructions"] = instructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
}
}

View File

@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
})

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
@@ -13,6 +14,7 @@ import (
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
apiKeyAuthCachePrefix = "apikey:auth:"
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
func apiKeyAuthCacheKey(key string) string {
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
}
type apiKeyCache struct {
rdb *redis.Client
}
@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
if err != nil {
return nil, err
}
var entry service.APIKeyAuthCacheEntry
if err := json.Unmarshal(val, &entry); err != nil {
return nil, err
}
return &entry, nil
}
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
if entry == nil {
return nil
}
payload, err := json.Marshal(entry)
if err != nil {
return err
}
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
}
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
}

View File

@@ -6,7 +6,9 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
return apiKeyEntityToService(m), nil
}
// GetOwnerID 根据 API Key ID 获取其所有者(用户)ID。
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者用户ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 使用 Select() 只查询必要字段,减少数据传输量
// - 不加载完整的 API Key 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
// - 适用于删除等只需 key 与用户 ID 的场景
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldUserID).
Select(apikey.FieldKey, apikey.FieldUserID).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
return "", 0, service.ErrAPIKeyNotFound
}
return 0, err
return "", 0, err
}
return m.UserID, nil
return m.Key, m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
Select(
apikey.FieldID,
apikey.FieldUserID,
apikey.FieldGroupID,
apikey.FieldStatus,
apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
q.Select(
group.FieldID,
group.FieldName,
group.FieldPlatform,
group.FieldStatus,
group.FieldSubscriptionType,
group.FieldRateMultiplier,
group.FieldDailyLimitUsd,
group.FieldWeeklyLimitUsd,
group.FieldMonthlyLimitUsd,
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
)
}).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
@@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err
}
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.UserIDEQ(userID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.GroupIDEQ(groupID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil

View File

@@ -326,7 +326,8 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": ""
"identity_patch_prompt": "",
"home_content": ""
}
}`,
},
@@ -389,11 +390,11 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo)
userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil)
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
@@ -565,6 +566,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t
return nil
}
func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
return nil
}
func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return nil
}
type stubGroupRepo struct{}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
@@ -737,12 +750,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
return &clone, nil
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
key, ok := r.byID[id]
if !ok {
return 0, service.ErrAPIKeyNotFound
return "", 0, service.ErrAPIKeyNotFound
}
return key.UserID, nil
return key.Key, key.UserID, nil
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -754,6 +767,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
return &clone, nil
}
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
return r.GetByKey(ctx, key)
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
if key == nil {
return errors.New("nil key")
@@ -868,6 +885,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog
}

View File

@@ -31,6 +31,7 @@ func ProvideRouter(
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
settingService *service.SettingService,
redisClient *redis.Client,
) *gin.Engine {
if cfg.Server.Mode == "release" {
@@ -49,7 +50,7 @@ func ProvideRouter(
}
}
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, settingService, cfg, redisClient)
}
// ProvideHTTPServer 提供 HTTP 服务器

View File

@@ -27,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
return "", 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if f.getByKey == nil {
@@ -36,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK
}
return f.getByKey(ctx, key)
}
func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
return f.GetByKey(ctx, key)
}
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
@@ -66,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
type googleErrorResponse struct {
Error struct {

View File

@@ -256,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
return "", 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -267,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
return r.GetByKey(ctx, key)
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
@@ -307,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error

View File

@@ -1,6 +1,8 @@
package server
import (
"log"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -21,6 +23,7 @@ func SetupRouter(
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
settingService *service.SettingService,
cfg *config.Config,
redisClient *redis.Client,
) *gin.Engine {
@@ -29,9 +32,17 @@ func SetupRouter(
r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
// Serve embedded frontend if available
// Serve embedded frontend with settings injection if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
frontendServer, err := web.NewFrontendServer(settingService)
if err != nil {
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
r.Use(web.ServeEmbeddedFrontend())
} else {
// Register cache invalidation callback
settingService.SetOnUpdateCallback(frontendServer.InvalidateCache)
r.Use(frontendServer.Middleware())
}
}
// 注册路由

View File

@@ -244,14 +244,15 @@ type ProxyExitInfoProber interface {
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewAdminService creates a new AdminService
@@ -264,16 +265,18 @@ func NewAdminService(
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -323,6 +326,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
if input.Email != "" {
user.Email = input.Email
@@ -355,6 +360,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
}
}
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
@@ -393,6 +403,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
log.Printf("delete user failed: user_id=%d err=%v", id, err)
return err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id)
}
return nil
}
@@ -420,6 +433,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
balanceDiff := user.Balance - oldBalance
if s.authCacheInvalidator != nil && balanceDiff != 0 {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService != nil {
go func() {
@@ -431,7 +448,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
}()
}
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := GenerateRedeemCode()
if err != nil {
@@ -675,10 +691,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
var groupKeys []string
if s.authCacheInvalidator != nil {
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
if err == nil {
groupKeys = keys
}
}
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
if err != nil {
return err
@@ -697,6 +724,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
}
}()
}
if s.authCacheInvalidator != nil {
for _, key := range groupKeys {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
}
}
return nil
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
type balanceUserRepoStub struct {
*userRepoStub
updateErr error
updated []*User
}
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
if s.updateErr != nil {
return s.updateErr
}
if user == nil {
return nil
}
clone := *user
s.updated = append(s.updated, &clone)
if s.userRepoStub != nil {
s.userRepoStub.user = &clone
}
return nil
}
type balanceRedeemRepoStub struct {
*redeemRepoStub
created []*RedeemCode
}
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
if code == nil {
return nil
}
clone := *code
s.created = append(s.created, &clone)
return nil
}
type authCacheInvalidatorStub struct {
userIDs []int64
groupIDs []int64
keys []string
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
s.keys = append(s.keys, key)
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
s.userIDs = append(s.userIDs, userID)
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
s.groupIDs = append(s.groupIDs, groupID)
}
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
require.NoError(t, err)
require.Equal(t, []int64{7}, invalidator.userIDs)
require.Len(t, redeemRepo.created, 1)
}
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
require.NoError(t, err)
require.Empty(t, invalidator.userIDs)
require.Empty(t, redeemRepo.created)
}

View File

@@ -0,0 +1,46 @@
package service
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct {
APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist,omitempty"`
IPBlacklist []string `json:"ip_blacklist,omitempty"`
User APIKeyAuthUserSnapshot `json:"user"`
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
}
// APIKeyAuthUserSnapshot 用户快照
type APIKeyAuthUserSnapshot struct {
ID int64 `json:"id"`
Status string `json:"status"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
}
// APIKeyAuthGroupSnapshot 分组快照
type APIKeyAuthGroupSnapshot struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
type APIKeyAuthCacheEntry struct {
NotFound bool `json:"not_found"`
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
}

View File

@@ -0,0 +1,269 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/dgraph-io/ristretto"
)
type apiKeyAuthCacheConfig struct {
l1Size int
l1TTL time.Duration
l2TTL time.Duration
negativeTTL time.Duration
jitterPercent int
singleflight bool
}
var (
jitterRandMu sync.Mutex
// 认证缓存抖动使用独立随机源,避免全局 Seed
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
if cfg == nil {
return apiKeyAuthCacheConfig{}
}
auth := cfg.APIKeyAuth
return apiKeyAuthCacheConfig{
l1Size: auth.L1Size,
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
jitterPercent: auth.JitterPercent,
singleflight: auth.Singleflight,
}
}
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
return c.l1Size > 0 && c.l1TTL > 0
}
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
return c.l2TTL > 0
}
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return c.negativeTTL > 0
}
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return ttl
}
if c.jitterPercent <= 0 {
return ttl
}
percent := c.jitterPercent
if percent > 100 {
percent = 100
}
delta := float64(percent) / 100
jitterRandMu.Lock()
randVal := jitterRand.Float64()
jitterRandMu.Unlock()
factor := 1 - delta + randVal*(2*delta)
if factor <= 0 {
return ttl
}
return time.Duration(float64(ttl) * factor)
}
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
if !s.authCfg.l1Enabled() {
return
}
cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: int64(s.authCfg.l1Size) * 10,
MaxCost: int64(s.authCfg.l1Size),
BufferItems: 64,
})
if err != nil {
return
}
s.authCacheL1 = cache
}
func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:])
}
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
if s.authCacheL1 != nil {
if val, ok := s.authCacheL1.Get(cacheKey); ok {
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
return entry, true
}
}
}
if s.cache == nil || !s.authCfg.l2Enabled() {
return nil, false
}
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
if err != nil {
return nil, false
}
s.setAuthCacheL1(cacheKey, entry)
return entry, true
}
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
if s.authCacheL1 == nil || entry == nil {
return
}
ttl := s.authCfg.l1TTL
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
ttl = s.authCfg.negativeTTL
}
ttl = s.authCfg.jitterTTL(ttl)
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
}
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
if entry == nil {
return
}
s.setAuthCacheL1(cacheKey, entry)
if s.cache == nil || !s.authCfg.l2Enabled() {
return
}
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
}
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
if s.authCacheL1 != nil {
s.authCacheL1.Del(cacheKey)
}
if s.cache == nil {
return
}
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
}
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
if errors.Is(err, ErrAPIKeyNotFound) {
entry := &APIKeyAuthCacheEntry{NotFound: true}
if s.authCfg.negativeEnabled() {
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
}
return entry, nil
}
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
snapshot := s.snapshotFromAPIKey(apiKey)
if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
}
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
return entry, nil
}
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
if entry == nil {
return nil, false, nil
}
if entry.NotFound {
return nil, true, ErrAPIKeyNotFound
}
if entry.Snapshot == nil {
return nil, false, nil
}
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil {
return nil
}
snapshot := &APIKeyAuthSnapshot{
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
Status: apiKey.Status,
IPWhitelist: apiKey.IPWhitelist,
IPBlacklist: apiKey.IPBlacklist,
User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID,
Status: apiKey.User.Status,
Role: apiKey.User.Role,
Balance: apiKey.User.Balance,
Concurrency: apiKey.User.Concurrency,
},
}
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
}
}
return snapshot
}
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
if snapshot == nil {
return nil
}
apiKey := &APIKey{
ID: snapshot.APIKeyID,
UserID: snapshot.UserID,
GroupID: snapshot.GroupID,
Key: key,
Status: snapshot.Status,
IPWhitelist: snapshot.IPWhitelist,
IPBlacklist: snapshot.IPBlacklist,
User: &User{
ID: snapshot.User.ID,
Status: snapshot.User.Status,
Role: snapshot.User.Role,
Balance: snapshot.User.Balance,
Concurrency: snapshot.User.Concurrency,
},
}
if snapshot.Group != nil {
apiKey.Group = &Group{
ID: snapshot.Group.ID,
Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status,
Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
}
}
return apiKey
}

View File

@@ -0,0 +1,48 @@
package service
import "context"
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
if key == "" {
return
}
cacheKey := s.authCacheKey(key)
s.deleteAuthCache(ctx, cacheKey)
}
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
if userID <= 0 {
return
}
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
if err != nil {
return
}
s.deleteAuthCacheByKeys(ctx, keys)
}
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
if groupID <= 0 {
return
}
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
if err != nil {
return
}
s.deleteAuthCacheByKeys(ctx, keys)
}
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
if len(keys) == 0 {
return
}
for _, key := range keys {
if key == "" {
continue
}
s.deleteAuthCache(ctx, s.authCacheKey(key))
}
}

View File

@@ -12,6 +12,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
var (
@@ -31,9 +33,11 @@ const (
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID用于删除等轻量场景
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
GetByKey(ctx context.Context, key string) (*APIKey, error)
// GetByKeyForAuth 认证专用查询,返回最小字段集
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
@@ -45,6 +49,8 @@ type APIKeyRepository interface {
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
}
// APIKeyCache defines cache operations for API key service
@@ -55,6 +61,17 @@ type APIKeyCache interface {
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
type APIKeyAuthCacheInvalidator interface {
InvalidateAuthCacheByKey(ctx context.Context, key string)
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
}
// CreateAPIKeyRequest 创建API Key请求
@@ -83,6 +100,9 @@ type APIKeyService struct {
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
@@ -94,7 +114,7 @@ func NewAPIKeyService(
cache APIKeyCache,
cfg *config.Config,
) *APIKeyService {
return &APIKeyService{
svc := &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -102,6 +122,8 @@ func NewAPIKeyService(
cache: cache,
cfg: cfg,
}
svc.initAuthCache(cfg)
return svc
}
// GenerateKey 生成随机API Key
@@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("create api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
return apiKey, nil
}
@@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetByKey 根据Key字符串获取API Key用于认证
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
cacheKey := s.authCacheKey(key)
// 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
}
if s.authCfg.singleflight {
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
return s.loadAuthCacheEntry(ctx, key, cacheKey)
})
if err != nil {
return nil, err
}
entry, _ := value.(*APIKeyAuthCacheEntry)
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
} else {
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
if err != nil {
return nil, err
}
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
}
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
// 缓存到Redis可选TTL设置为5分钟
if s.cache != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
}
apiKey.Key = key
return apiKey, nil
}
@@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return nil, fmt.Errorf("update api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
return apiKey, nil
}
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
if err != nil {
return fmt.Errorf("get api key: %w", err)
}
@@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
return ErrInsufficientPerms
}
// 清除Redis缓存使用 ownerID 而非 apiKey.UserID
// 清除Redis缓存使用 userID 而非 apiKey.UserID
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
}
s.InvalidateAuthCacheByKey(ctx, key)
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)

View File

@@ -0,0 +1,417 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
type authRepoStub struct {
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
}
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call")
}
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
panic("unexpected GetKeyAndOwnerID call")
}
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
if s.getByKeyForAuth == nil {
panic("unexpected GetByKeyForAuth call")
}
return s.getByKeyForAuth(ctx, key)
}
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
if s.listKeysByUserID == nil {
panic("unexpected ListKeysByUserID call")
}
return s.listKeysByUserID(ctx, userID)
}
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
if s.listKeysByGroupID == nil {
panic("unexpected ListKeysByGroupID call")
}
return s.listKeysByGroupID(ctx, groupID)
}
type authCacheStub struct {
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
setAuthKeys []string
deleteAuthKeys []string
}
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil
}
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil
}
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
if s.getAuthCache == nil {
return nil, redis.Nil
}
return s.getAuthCache(ctx, key)
}
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
s.setAuthKeys = append(s.setAuthKeys, key)
return nil
}
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, errors.New("unexpected repo call")
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: APIKeyAuthUserSnapshot{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &APIKeyAuthGroupSnapshot{
ID: groupID,
Name: "g",
Platform: PlatformAnthropic,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
},
},
}
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return cacheEntry, nil
}
apiKey, err := svc.GetByKey(context.Background(), "k1")
require.NoError(t, err)
require.Equal(t, int64(1), apiKey.ID)
require.Equal(t, int64(2), apiKey.User.ID)
require.Equal(t, groupID, apiKey.Group.ID)
}
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, errors.New("unexpected repo call")
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil
}
_, err := svc.GetByKey(context.Background(), "missing")
require.ErrorIs(t, err, ErrAPIKeyNotFound)
}
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return &APIKey{
ID: 5,
UserID: 7,
Status: StatusActive,
User: &User{
ID: 7,
Status: StatusActive,
Role: RoleUser,
Balance: 12,
Concurrency: 2,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
apiKey, err := svc.GetByKey(context.Background(), "k2")
require.NoError(t, err)
require.Equal(t, int64(5), apiKey.ID)
require.Len(t, cache.setAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
var calls int32
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&calls, 1)
return &APIKey{
ID: 21,
UserID: 3,
Status: StatusActive,
User: &User{
ID: 3,
Status: StatusActive,
Role: RoleUser,
Balance: 5,
Concurrency: 2,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L1Size: 1000,
L1TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1")
require.NoError(t, err)
svc.authCacheL1.Wait()
cacheKey := svc.authCacheKey("k-l1")
_, ok := svc.authCacheL1.Get(cacheKey)
require.True(t, ok)
_, err = svc.GetByKey(context.Background(), "k-l1")
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
return []string{"k1", "k2"}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2)
}
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
return []string{"k1", "k2"}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2)
}
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
return nil, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, ErrAPIKeyNotFound
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
_, err := svc.GetByKey(context.Background(), "missing")
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Len(t, cache.setAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
var calls int32
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&calls, 1)
time.Sleep(50 * time.Millisecond)
return &APIKey{
ID: 11,
UserID: 2,
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 1,
Concurrency: 1,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
Singleflight: true,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
start := make(chan struct{})
wg := sync.WaitGroup{}
errs := make([]error, 5)
for i := 0; i < 5; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
<-start
_, err := svc.GetByKey(context.Background(), "k1")
errs[idx] = err
}(i)
}
close(start)
wg.Wait()
for _, err := range errs {
require.NoError(t, err)
}
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}

View File

@@ -20,13 +20,12 @@ import (
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
ownerID int64 // GetOwnerID 的返回值
ownerErr error // GetOwnerID 的错误返回值
apiKey *APIKey // GetKeyAndOwnerID 的返回值
getByIDErr error // GetKeyAndOwnerID 的错误返回值
deleteErr error // Delete 的错误返回值
deletedIDs []int64 // 记录已删除的 API Key ID 列表
}
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
if s.getByIDErr != nil {
return nil, s.getByIDErr
}
if s.apiKey != nil {
clone := *s.apiKey
return &clone, nil
}
panic("unexpected GetByID call")
}
// GetOwnerID 返回预设的所有者 ID 或错误。
// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。
func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return s.ownerID, s.ownerErr
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
if s.getByIDErr != nil {
return "", 0, s.getByIDErr
}
if s.apiKey != nil {
return s.apiKey.Key, s.apiKey.UserID, nil
}
return "", 0, ErrAPIKeyNotFound
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
@@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
panic("unexpected ListKeysByUserID call")
}
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
panic("unexpected ListKeysByGroupID call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
// - invalidated: 记录被清除缓存的用户 ID 列表
type apiKeyCacheStub struct {
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
}
// GetCreateAttemptCount 返回 0表示用户未超过创建次数限制
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil
}
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
return nil
}
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 1
// - GetKeyAndOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2不匹配
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
require.ErrorIs(t, err, ErrInsufficientPerms)
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
require.Empty(t, cache.invalidated) // 验证缓存未被清除
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 7
// - GetKeyAndOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7匹配
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID
// - 返回 nil 错误
func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetOwnerID 返回正确的所有者 ID
// - GetKeyAndOwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
ownerID: 3,
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
require.ErrorContains(t, err, "delete api key")
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
}

View File

@@ -0,0 +1,33 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
invalidator := &authCacheInvalidatorStub{}
svc := &UsageService{authCacheInvalidator: invalidator}
svc.invalidateUsageCaches(context.Background(), 7, false)
require.Empty(t, invalidator.userIDs)
svc.invalidateUsageCaches(context.Background(), 7, true)
require.Equal(t, []int64{7}, invalidator.userIDs)
}
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
invalidator := &authCacheInvalidatorStub{}
svc := &RedeemService{authCacheInvalidator: invalidator}
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
groupID := int64(3)
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})
require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
}

View File

@@ -90,6 +90,7 @@ const (
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML或 URL 作为 iframe src
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量

View File

@@ -50,13 +50,15 @@ type UpdateGroupRequest struct {
// GroupService 分组管理服务
type GroupService struct {
groupRepo GroupRepository
groupRepo GroupRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewGroupService 创建分组服务实例
func NewGroupService(groupRepo GroupRepository) *GroupService {
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
return &GroupService{
groupRepo: groupRepo,
groupRepo: groupRepo,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
@@ -167,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
return fmt.Errorf("get group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete group: %w", err)
}

View File

@@ -0,0 +1,528 @@
package service
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const (
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
codexCacheTTL = 15 * time.Minute
)
var codexModelMap = map[string]string{
"gpt-5.1-codex": "gpt-5.1-codex",
"gpt-5.1-codex-low": "gpt-5.1-codex",
"gpt-5.1-codex-medium": "gpt-5.1-codex",
"gpt-5.1-codex-high": "gpt-5.1-codex",
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
"gpt-5.2": "gpt-5.2",
"gpt-5.2-none": "gpt-5.2",
"gpt-5.2-low": "gpt-5.2",
"gpt-5.2-medium": "gpt-5.2",
"gpt-5.2-high": "gpt-5.2",
"gpt-5.2-xhigh": "gpt-5.2",
"gpt-5.2-codex": "gpt-5.2-codex",
"gpt-5.2-codex-low": "gpt-5.2-codex",
"gpt-5.2-codex-medium": "gpt-5.2-codex",
"gpt-5.2-codex-high": "gpt-5.2-codex",
"gpt-5.2-codex-xhigh": "gpt-5.2-codex",
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
"gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
"gpt-5.1": "gpt-5.1",
"gpt-5.1-none": "gpt-5.1",
"gpt-5.1-low": "gpt-5.1",
"gpt-5.1-medium": "gpt-5.1",
"gpt-5.1-high": "gpt-5.1",
"gpt-5.1-chat-latest": "gpt-5.1",
"gpt-5-codex": "gpt-5.1-codex",
"codex-mini-latest": "gpt-5.1-codex-mini",
"gpt-5-codex-mini": "gpt-5.1-codex-mini",
"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
"gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
"gpt-5": "gpt-5.1",
"gpt-5-mini": "gpt-5.1",
"gpt-5-nano": "gpt-5.1",
}
type codexTransformResult struct {
Modified bool
NormalizedModel string
PromptCacheKey string
}
type opencodeCacheMetadata struct {
ETag string `json:"etag"`
LastFetch string `json:"lastFetch,omitempty"`
LastChecked int64 `json:"lastChecked"`
}
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result := codexTransformResult{}
model := ""
if v, ok := reqBody["model"].(string); ok {
model = v
}
normalizedModel := normalizeCodexModel(model)
if normalizedModel != "" {
if model != normalizedModel {
reqBody["model"] = normalizedModel
result.Modified = true
}
result.NormalizedModel = normalizedModel
}
if v, ok := reqBody["store"].(bool); !ok || v {
reqBody["store"] = false
result.Modified = true
}
if v, ok := reqBody["stream"].(bool); !ok || !v {
reqBody["stream"] = true
result.Modified = true
}
if _, ok := reqBody["max_output_tokens"]; ok {
delete(reqBody, "max_output_tokens")
result.Modified = true
}
if _, ok := reqBody["max_completion_tokens"]; ok {
delete(reqBody, "max_completion_tokens")
result.Modified = true
}
if normalizeCodexTools(reqBody) {
result.Modified = true
}
if v, ok := reqBody["prompt_cache_key"].(string); ok {
result.PromptCacheKey = strings.TrimSpace(v)
}
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
existingInstructions, _ := reqBody["instructions"].(string)
existingInstructions = strings.TrimSpace(existingInstructions)
if instructions != "" {
if existingInstructions != "" && existingInstructions != instructions {
if input, ok := reqBody["input"].([]any); ok {
reqBody["input"] = prependSystemInstruction(input, existingInstructions)
result.Modified = true
}
}
if existingInstructions != instructions {
reqBody["instructions"] = instructions
result.Modified = true
}
}
if input, ok := reqBody["input"].([]any); ok {
input = filterCodexInput(input)
input = normalizeOrphanedToolOutputs(input)
reqBody["input"] = input
result.Modified = true
}
return result
}
func normalizeCodexModel(model string) string {
if model == "" {
return "gpt-5.1"
}
modelID := model
if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
if mapped := getNormalizedCodexModel(modelID); mapped != "" {
return mapped
}
normalized := strings.ToLower(modelID)
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
return "gpt-5.2-codex"
}
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
return "gpt-5.2"
}
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
return "gpt-5.1-codex-max"
}
if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
return "gpt-5.1-codex-mini"
}
if strings.Contains(normalized, "codex-mini-latest") ||
strings.Contains(normalized, "gpt-5-codex-mini") ||
strings.Contains(normalized, "gpt 5 codex mini") {
return "codex-mini-latest"
}
if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
return "gpt-5.1-codex"
}
if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
return "gpt-5.1"
}
if strings.Contains(normalized, "codex") {
return "gpt-5.1-codex"
}
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
return "gpt-5.1"
}
return "gpt-5.1"
}
func getNormalizedCodexModel(modelID string) string {
if modelID == "" {
return ""
}
if mapped, ok := codexModelMap[modelID]; ok {
return mapped
}
lower := strings.ToLower(modelID)
for key, value := range codexModelMap {
if strings.ToLower(key) == lower {
return value
}
}
return ""
}
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
cacheDir := codexCachePath("")
if cacheDir == "" {
return ""
}
cacheFile := filepath.Join(cacheDir, cacheFileName)
metaFile := filepath.Join(cacheDir, metaFileName)
var cachedContent string
if content, ok := readFile(cacheFile); ok {
cachedContent = content
}
var meta opencodeCacheMetadata
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
return cachedContent
}
}
content, etag, status, err := fetchWithETag(url, meta.ETag)
if err == nil && status == http.StatusNotModified && cachedContent != "" {
return cachedContent
}
if err == nil && status >= 200 && status < 300 && content != "" {
_ = writeFile(cacheFile, content)
meta = opencodeCacheMetadata{
ETag: etag,
LastFetch: time.Now().UTC().Format(time.RFC3339),
LastChecked: time.Now().UnixMilli(),
}
_ = writeJSON(metaFile, meta)
return content
}
return cachedContent
}
func getOpenCodeCodexHeader() string {
return getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
}
func GetOpenCodeInstructions() string {
return getOpenCodeCodexHeader()
}
func filterCodexInput(input []any) []any {
filtered := make([]any, 0, len(input))
for _, item := range input {
m, ok := item.(map[string]any)
if !ok {
filtered = append(filtered, item)
continue
}
if typ, ok := m["type"].(string); ok && typ == "item_reference" {
continue
}
delete(m, "id")
filtered = append(filtered, m)
}
return filtered
}
func prependSystemInstruction(input []any, instructions string) []any {
message := map[string]any{
"role": "system",
"content": []any{
map[string]any{
"type": "input_text",
"text": instructions,
},
},
}
return append([]any{message}, input...)
}
func normalizeCodexTools(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
return false
}
tools, ok := rawTools.([]any)
if !ok {
return false
}
modified := false
for idx, tool := range tools {
toolMap, ok := tool.(map[string]any)
if !ok {
continue
}
toolType, _ := toolMap["type"].(string)
if strings.TrimSpace(toolType) != "function" {
continue
}
function, ok := toolMap["function"].(map[string]any)
if !ok {
continue
}
if _, ok := toolMap["name"]; !ok {
if name, ok := function["name"].(string); ok && strings.TrimSpace(name) != "" {
toolMap["name"] = name
modified = true
}
}
if _, ok := toolMap["description"]; !ok {
if desc, ok := function["description"].(string); ok && strings.TrimSpace(desc) != "" {
toolMap["description"] = desc
modified = true
}
}
if _, ok := toolMap["parameters"]; !ok {
if params, ok := function["parameters"]; ok {
toolMap["parameters"] = params
modified = true
}
}
if _, ok := toolMap["strict"]; !ok {
if strict, ok := function["strict"]; ok {
toolMap["strict"] = strict
modified = true
}
}
tools[idx] = toolMap
}
if modified {
reqBody["tools"] = tools
}
return modified
}
func normalizeOrphanedToolOutputs(input []any) []any {
functionCallIDs := map[string]bool{}
localShellCallIDs := map[string]bool{}
customToolCallIDs := map[string]bool{}
for _, item := range input {
m, ok := item.(map[string]any)
if !ok {
continue
}
callID := getCallID(m)
if callID == "" {
continue
}
switch m["type"] {
case "function_call":
functionCallIDs[callID] = true
case "local_shell_call":
localShellCallIDs[callID] = true
case "custom_tool_call":
customToolCallIDs[callID] = true
}
}
output := make([]any, 0, len(input))
for _, item := range input {
m, ok := item.(map[string]any)
if !ok {
output = append(output, item)
continue
}
switch m["type"] {
case "function_call_output":
callID := getCallID(m)
if callID == "" || (!functionCallIDs[callID] && !localShellCallIDs[callID]) {
output = append(output, convertOrphanedOutputToMessage(m, callID))
continue
}
case "custom_tool_call_output":
callID := getCallID(m)
if callID == "" || !customToolCallIDs[callID] {
output = append(output, convertOrphanedOutputToMessage(m, callID))
continue
}
case "local_shell_call_output":
callID := getCallID(m)
if callID == "" || !localShellCallIDs[callID] {
output = append(output, convertOrphanedOutputToMessage(m, callID))
continue
}
}
output = append(output, m)
}
return output
}
func getCallID(item map[string]any) string {
raw, ok := item["call_id"]
if !ok {
return ""
}
callID, ok := raw.(string)
if !ok {
return ""
}
callID = strings.TrimSpace(callID)
if callID == "" {
return ""
}
return callID
}
func convertOrphanedOutputToMessage(item map[string]any, callID string) map[string]any {
toolName := "tool"
if name, ok := item["name"].(string); ok && name != "" {
toolName = name
}
labelID := callID
if labelID == "" {
labelID = "unknown"
}
text := stringifyOutput(item["output"])
if len(text) > 16000 {
text = text[:16000] + "\n...[truncated]"
}
return map[string]any{
"type": "message",
"role": "assistant",
"content": fmt.Sprintf("[Previous %s result; call_id=%s]: %s", toolName, labelID, text),
}
}
func stringifyOutput(output any) string {
switch v := output.(type) {
case string:
return v
default:
if data, err := json.Marshal(v); err == nil {
return string(data)
}
return fmt.Sprintf("%v", v)
}
}
func codexCachePath(filename string) string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
cacheDir := filepath.Join(home, ".opencode", "cache")
if filename == "" {
return cacheDir
}
return filepath.Join(cacheDir, filename)
}
func readFile(path string) (string, bool) {
if path == "" {
return "", false
}
data, err := os.ReadFile(path)
if err != nil {
return "", false
}
return string(data), true
}
func writeFile(path, content string) error {
if path == "" {
return fmt.Errorf("empty cache path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
return os.WriteFile(path, []byte(content), 0o644)
}
func loadJSON(path string, target any) bool {
data, err := os.ReadFile(path)
if err != nil {
return false
}
if err := json.Unmarshal(data, target); err != nil {
return false
}
return true
}
func writeJSON(path string, value any) error {
if path == "" {
return fmt.Errorf("empty json path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.Marshal(value)
if err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}
func fetchWithETag(url, etag string) (string, string, int, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", "", 0, err
}
req.Header.Set("User-Agent", "sub2api-codex")
if etag != "" {
req.Header.Set("If-None-Match", etag)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", "", 0, err
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", "", resp.StatusCode, err
}
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
}

View File

@@ -12,6 +12,7 @@ import (
"io"
"log"
"net/http"
"os"
"regexp"
"sort"
"strconv"
@@ -20,6 +21,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
@@ -528,31 +530,38 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract model and stream from parsed body
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
promptCacheKey := ""
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
}
// Track if body needs re-serialization
bodyModified := false
originalModel := reqModel
// Apply model mapping
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
reqBody["model"] = mappedModel
bodyModified = true
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
// Apply model mapping (skip for Codex CLI for transparent forwarding)
mappedModel := reqModel
if !isCodexCLI {
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
reqBody["model"] = mappedModel
bodyModified = true
}
}
// For OAuth accounts using ChatGPT internal API:
// 1. Add store: false
// 2. Normalize input format for Codex API compatibility
if account.Type == AccountTypeOAuth {
reqBody["store"] = false
bodyModified = true
// Normalize input format: convert AI SDK multi-part content format to simplified format
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
// Codex API expects: {"content": "..."}
if normalizeInputForCodexAPI(reqBody) {
if account.Type == AccountTypeOAuth && !isCodexCLI {
codexResult := applyCodexOAuthTransform(reqBody)
if codexResult.Modified {
bodyModified = true
}
if codexResult.NormalizedModel != "" {
mappedModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
}
}
// Re-serialize body only if modified
@@ -571,7 +580,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
if err != nil {
return nil, err
}
@@ -632,7 +641,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil
}
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
// Determine target URL based on account type
var targetURL string
switch account.Type {
@@ -672,12 +681,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
// Set accept header based on stream mode
if isStream {
req.Header.Set("accept", "text/event-stream")
} else {
req.Header.Set("accept", "application/json")
}
}
// Whitelist passthrough headers
@@ -689,6 +692,22 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
}
}
}
if account.Type == AccountTypeOAuth {
req.Header.Set("OpenAI-Beta", "responses=experimental")
if isCodexCLI {
req.Header.Set("originator", "codex_cli_rs")
} else {
req.Header.Set("originator", "opencode")
}
req.Header.Set("accept", "text/event-stream")
if promptCacheKey != "" {
req.Header.Set("conversation_id", promptCacheKey)
req.Header.Set("session_id", promptCacheKey)
} else {
req.Header.Del("conversation_id")
req.Header.Del("session_id")
}
}
// Apply custom User-Agent if configured
customUA := account.GetOpenAIUserAgent()
@@ -706,6 +725,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
logUpstreamErrorBody(account.ID, resp.StatusCode, body)
// Check custom error codes
if !account.ShouldHandleErrorCode(resp.StatusCode) {
@@ -764,6 +784,24 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
func logUpstreamErrorBody(accountID int64, statusCode int, body []byte) {
if strings.ToLower(strings.TrimSpace(os.Getenv("GATEWAY_LOG_UPSTREAM_ERROR_BODY"))) != "true" {
return
}
maxBytes := 2048
if rawMax := strings.TrimSpace(os.Getenv("GATEWAY_LOG_UPSTREAM_ERROR_BODY_MAX_BYTES")); rawMax != "" {
if parsed, err := strconv.Atoi(rawMax); err == nil && parsed > 0 {
maxBytes = parsed
}
}
if len(body) > maxBytes {
body = body[:maxBytes]
}
log.Printf("Upstream error body: account=%d status=%d body=%q", accountID, statusCode, string(body))
}
// openaiStreamingResult streaming response result
type openaiStreamingResult struct {
usage *OpenAIUsage
@@ -1016,6 +1054,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return nil, err
}
if account.Type == AccountTypeOAuth {
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
}
}
// Parse usage
var response struct {
Usage struct {
@@ -1055,6 +1100,110 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return usage, nil
}
func isEventStreamResponse(header http.Header) bool {
contentType := strings.ToLower(header.Get("Content-Type"))
return strings.Contains(contentType, "text/event-stream")
}
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
usage := &OpenAIUsage{}
if ok {
var response struct {
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
} `json:"input_tokens_details"`
} `json:"usage"`
}
if err := json.Unmarshal(finalResponse, &response); err == nil {
usage.InputTokens = response.Usage.InputTokens
usage.OutputTokens = response.Usage.OutputTokens
usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
}
body = finalResponse
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
} else {
usage = s.parseSSEUsageFromBody(bodyText)
if originalModel != mappedModel {
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
}
body = []byte(bodyText)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := "application/json; charset=utf-8"
if !ok {
contentType = resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream"
}
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
var event struct {
Type string `json:"type"`
Response json.RawMessage `json:"response"`
}
if json.Unmarshal([]byte(data), &event) != nil {
continue
}
if event.Type == "response.done" || event.Type == "response.completed" {
if len(event.Response) > 0 {
return event.Response, true
}
}
}
return nil, false
}
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{}
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
s.parseSSEUsage(data, usage)
}
return usage
}
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
lines := strings.Split(body, "\n")
for i, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
continue
}
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
}
return strings.Join(lines, "\n")
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
@@ -1094,101 +1243,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
return newBody
}
// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format
// that the ChatGPT internal Codex API expects.
//
// AI SDK sends content as an array of typed objects:
//
// {"content": [{"type": "input_text", "text": "hello"}]}
//
// ChatGPT Codex API expects content as a simple string:
//
// {"content": "hello"}
//
// This function modifies reqBody in-place and returns true if any modification was made.
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
input, ok := reqBody["input"]
if !ok {
return false
}
// Handle case where input is a simple string (already compatible)
if _, isString := input.(string); isString {
return false
}
// Handle case where input is an array of messages
inputArray, ok := input.([]any)
if !ok {
return false
}
modified := false
for _, item := range inputArray {
message, ok := item.(map[string]any)
if !ok {
continue
}
content, ok := message["content"]
if !ok {
continue
}
// If content is already a string, no conversion needed
if _, isString := content.(string); isString {
continue
}
// If content is an array (AI SDK format), convert to string
contentArray, ok := content.([]any)
if !ok {
continue
}
// Extract text from content array
var textParts []string
for _, part := range contentArray {
partMap, ok := part.(map[string]any)
if !ok {
continue
}
// Handle different content types
partType, _ := partMap["type"].(string)
switch partType {
case "input_text", "text":
// Extract text from input_text or text type
if text, ok := partMap["text"].(string); ok {
textParts = append(textParts, text)
}
case "input_image", "image":
// For images, we need to preserve the original format
// as ChatGPT Codex API may support images in a different way
// For now, skip image parts (they will be lost in conversion)
// TODO: Consider preserving image data or handling it separately
continue
case "input_file", "file":
// Similar to images, file inputs may need special handling
continue
default:
// For unknown types, try to extract text if available
if text, ok := partMap["text"].(string); ok {
textParts = append(textParts, text)
}
}
}
// Convert content array to string
if len(textParts) > 0 {
message["content"] = strings.Join(textParts, "\n")
modified = true
}
}
return modified
}
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult

View File

@@ -220,7 +220,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
Credentials: map[string]any{"base_url": "://invalid-url"},
}
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
if err == nil {
t.Fatalf("expected error for invalid base_url when allowlist disabled")
}

View File

@@ -24,10 +24,11 @@ var (
// PromoService 优惠码服务
type PromoService struct {
promoRepo PromoCodeRepository
userRepo UserRepository
billingCacheService *BillingCacheService
entClient *dbent.Client
promoRepo PromoCodeRepository
userRepo UserRepository
billingCacheService *BillingCacheService
entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewPromoService 创建优惠码服务实例
@@ -36,12 +37,14 @@ func NewPromoService(
userRepo UserRepository,
billingCacheService *BillingCacheService,
entClient *dbent.Client,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) *PromoService {
return &PromoService{
promoRepo: promoRepo,
userRepo: userRepo,
billingCacheService: billingCacheService,
entClient: entClient,
promoRepo: promoRepo,
userRepo: userRepo,
billingCacheService: billingCacheService,
entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -145,6 +148,8 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
return fmt.Errorf("commit transaction: %w", err)
}
s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount)
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
@@ -157,6 +162,13 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
return nil
}
func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) {
if bonusAmount == 0 || s.authCacheInvalidator == nil {
return
}
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
// GenerateRandomCode 生成随机优惠码
func (s *PromoService) GenerateRandomCode() (string, error) {
bytes := make([]byte, 8)

View File

@@ -0,0 +1,122 @@
# Codex Running in OpenCode
You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles.
## CRITICAL: Tool Replacements
<critical_rule priority="0">
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
- NEVER use: apply_patch, applyPatch
- ALWAYS use: edit tool for ALL file modifications
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
</critical_rule>
<critical_rule priority="0">
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
- NEVER use: update_plan, updatePlan, read_plan, readPlan
- ALWAYS use: todowrite for task/plan updates, todoread to read plans
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
</critical_rule>
## Available OpenCode Tools
**File Operations:**
- `write` - Create new files
- Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode.
- `edit` - Modify existing files (REPLACES apply_patch)
- Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing.
- `read` - Read file contents
**Search/Discovery:**
- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`.
- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set.
- `list` - List directories (requires absolute paths)
**Execution:**
- `bash` - Run shell commands
- No workdir parameter; do not include it in tool calls.
- Always include a short description for the command.
- Do not use cd; use absolute paths in commands.
- Quote paths containing spaces with double quotes.
- Chain multiple commands with ';' or '&&'; avoid newlines.
- Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features.
- Do not use `ls`/`cat` in bash; use `list`/`read` tools instead.
- For deletions (rm), verify by listing parent dir with `list`.
**Network:**
- `webfetch` - Fetch web content
- Use fully-formed URLs (http/https; http auto-upgrades to https).
- Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required.
- Read-only; short cache window.
**Task Management:**
- `todowrite` - Manage tasks/plans (REPLACES update_plan)
- `todoread` - Read current plan
## Substitution Rules
Base instruction says: You MUST use instead:
apply_patch → edit
update_plan → todowrite
read_plan → todoread
**Path Usage:** Use per-tool conventions to avoid conflicts:
- Tool calls: `read`, `edit`, `write`, `list` require absolute paths.
- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed.
- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls.
- Tool schema overrides general path preferences—do not convert required absolute paths to relative.
## Verification Checklist
Before file/plan modifications:
1. Am I using "edit" NOT "apply_patch"?
2. Am I using "todowrite" NOT "update_plan"?
3. Is this tool in the approved list above?
4. Am I following each tool's path requirements?
If ANY answer is NO → STOP and correct before proceeding.
## OpenCode Working Style
**Communication:**
- Send brief preambles (8-12 words) before tool calls, building on prior context
- Provide progress updates during longer tasks
**Execution:**
- Keep working autonomously until query is fully resolved before yielding
- Don't return to user with partial solutions
**Code Approach:**
- New projects: Be ambitious and creative
- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise
**Testing:**
- If tests exist: Start specific to your changes, then broader validation
## Advanced Tools
**Task Tool (Sub-Agents):**
- Use the Task tool (functions.task) to launch sub-agents
- Check the Task tool description for current agent types and their capabilities
- Useful for complex analysis, specialized workflows, or tasks requiring isolated context
- The agent list is dynamically generated - refer to tool schema for available agents
**Parallelization:**
- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently.
- Reserve sequential calls for ordered or data-dependent steps.
**MCP Tools:**
- Model Context Protocol servers provide additional capabilities
- MCP tools are prefixed: `mcp__<server-name>__<tool-name>`
- Check your available tools for MCP integrations
- Use when the tool's functionality matches your task needs
## What Remains from Codex
Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations.
## Approvals & Safety
- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise.
- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification.
- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval.
- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`).

View File

@@ -0,0 +1,63 @@
<user_instructions priority="0">
<environment_override priority="0">
YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references.
</environment_override>
<tool_replacements priority="0">
<critical_rule priority="0">
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
- NEVER use: apply_patch, applyPatch
- ALWAYS use: edit tool for ALL file modifications
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
</critical_rule>
<critical_rule priority="0">
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
- NEVER use: update_plan, updatePlan
- ALWAYS use: todowrite for ALL task/plan operations
- Use todoread to read current plan
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
</critical_rule>
</tool_replacements>
<available_tools priority="0">
File Operations:
• write - Create new files
• edit - Modify existing files (REPLACES apply_patch)
• patch - Apply diff patches
• read - Read file contents
Search/Discovery:
• grep - Search file contents
• glob - Find files by pattern
• list - List directories (use relative paths)
Execution:
• bash - Run shell commands
Network:
• webfetch - Fetch web content
Task Management:
• todowrite - Manage tasks/plans (REPLACES update_plan)
• todoread - Read current plan
</available_tools>
<substitution_rules priority="0">
Base instruction says: You MUST use instead:
apply_patch → edit
update_plan → todowrite
read_plan → todoread
absolute paths → relative paths
</substitution_rules>
<verification_checklist priority="0">
Before file/plan modifications:
1. Am I using "edit" NOT "apply_patch"?
2. Am I using "todowrite" NOT "update_plan"?
3. Is this tool in the approved list above?
4. Am I using relative paths?
If ANY answer is NO → STOP and correct before proceeding.
</verification_checklist>
</user_instructions>

View File

@@ -68,12 +68,13 @@ type RedeemCodeResponse struct {
// RedeemService 兑换码服务
type RedeemService struct {
redeemRepo RedeemCodeRepository
userRepo UserRepository
subscriptionService *SubscriptionService
cache RedeemCache
billingCacheService *BillingCacheService
entClient *dbent.Client
redeemRepo RedeemCodeRepository
userRepo UserRepository
subscriptionService *SubscriptionService
cache RedeemCache
billingCacheService *BillingCacheService
entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewRedeemService 创建兑换码服务实例
@@ -84,14 +85,16 @@ func NewRedeemService(
cache RedeemCache,
billingCacheService *BillingCacheService,
entClient *dbent.Client,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
cache: cache,
billingCacheService: billingCacheService,
entClient: entClient,
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
cache: cache,
billingCacheService: billingCacheService,
entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -324,18 +327,33 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// invalidateRedeemCaches 失效兑换相关的缓存
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
if s.billingCacheService == nil {
return
}
switch redeemCode.Type {
case RedeemTypeBalance:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
case RedeemTypeConcurrency:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
case RedeemTypeSubscription:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
if redeemCode.GroupID != nil {
groupID := *redeemCode.GroupID
go func() {

View File

@@ -32,6 +32,8 @@ type SettingRepository interface {
type SettingService struct {
settingRepo SettingRepository
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
}
// NewSettingService 创建系统设置服务实例
@@ -65,6 +67,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyAPIBaseURL,
SettingKeyContactInfo,
SettingKeyDocURL,
SettingKeyHomeContent,
SettingKeyLinuxDoConnectEnabled,
}
@@ -91,10 +94,62 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil
}
// SetOnUpdateCallback sets a callback function to be called when settings are updated
// This is used for cache invalidation (e.g., HTML cache in frontend server)
func (s *SettingService) SetOnUpdateCallback(callback func()) {
s.onUpdate = callback
}
// SetVersion sets the application version for injection into public settings
func (s *SettingService) SetVersion(version string) {
s.version = version
}
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection
// This implements the web.PublicSettingsProvider interface
func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any, error) {
settings, err := s.GetPublicSettings(ctx)
if err != nil {
return nil, err
}
// Return a struct that matches the frontend's expected format
return &struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo,omitempty"`
SiteSubtitle string `json:"site_subtitle,omitempty"`
APIBaseURL string `json:"api_base_url,omitempty"`
ContactInfo string `json:"contact_info,omitempty"`
DocURL string `json:"doc_url,omitempty"`
HomeContent string `json:"home_content,omitempty"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
}, nil
}
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
updates := make(map[string]string)
@@ -136,6 +191,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocURL] = settings.DocURL
updates[SettingKeyHomeContent] = settings.HomeContent
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@@ -152,7 +208,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch)
updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt
return s.settingRepo.SetMultiple(ctx, updates)
err := s.settingRepo.SetMultiple(ctx, updates)
if err == nil && s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update
}
return err
}
// IsRegistrationEnabled 检查是否开放注册
@@ -263,6 +323,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
}
// 解析整数类型

View File

@@ -31,6 +31,7 @@ type SystemSettings struct {
APIBaseURL string
ContactInfo string
DocURL string
HomeContent string
DefaultConcurrency int
DefaultBalance float64
@@ -58,6 +59,7 @@ type PublicSettings struct {
APIBaseURL string
ContactInfo string
DocURL string
HomeContent string
LinuxDoOAuthEnabled bool
Version string
}

View File

@@ -54,17 +54,19 @@ type UsageStats struct {
// UsageService 使用统计服务
type UsageService struct {
usageRepo UsageLogRepository
userRepo UserRepository
entClient *dbent.Client
usageRepo UsageLogRepository
userRepo UserRepository
entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator) *UsageService {
return &UsageService{
usageRepo: usageRepo,
userRepo: userRepo,
entClient: entClient,
usageRepo: usageRepo,
userRepo: userRepo,
entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -118,10 +120,12 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
// 扣除用户余额
balanceUpdated := false
if inserted && req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
balanceUpdated = true
}
if tx != nil {
@@ -130,9 +134,18 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
}
s.invalidateUsageCaches(ctx, req.UserID, balanceUpdated)
return usageLog, nil
}
func (s *UsageService) invalidateUsageCaches(ctx context.Context, userID int64, balanceUpdated bool) {
if !balanceUpdated || s.authCacheInvalidator == nil {
return
}
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
// GetByID 根据ID获取使用日志
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id)

View File

@@ -55,13 +55,15 @@ type ChangePasswordRequest struct {
// UserService 用户服务
type UserService struct {
userRepo UserRepository
userRepo UserRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo UserRepository) *UserService {
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
return &UserService{
userRepo: userRepo,
userRepo: userRepo,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
// 更新字段
if req.Email != nil {
@@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return user, nil
}
@@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
return fmt.Errorf("update balance: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}
@@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu
if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
return fmt.Errorf("update concurrency: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}
@@ -192,12 +204,18 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}
// Delete 删除用户(管理员功能)
func (s *UserService) Delete(ctx context.Context, userID int64) error {
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete user: %w", err)
}

View File

@@ -77,12 +77,18 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
return svc
}
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
return apiKeyService
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewAPIKeyService,
ProvideAPIKeyAuthCacheInvalidator,
NewGroupService,
NewAccountService,
NewProxyService,

View File

@@ -4,11 +4,38 @@
package web
import (
"context"
"errors"
"net/http"
"github.com/gin-gonic/gin"
)
// PublicSettingsProvider is an interface to fetch public settings
// This stub is needed for compilation when frontend is not embedded
type PublicSettingsProvider interface {
GetPublicSettingsForInjection(ctx context.Context) (any, error)
}
// FrontendServer is a stub for non-embed builds
type FrontendServer struct{}
// NewFrontendServer returns an error when frontend is not embedded
func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer, error) {
return nil, errors.New("frontend not embedded")
}
// InvalidateCache is a no-op for non-embed builds
func (s *FrontendServer) InvalidateCache() {}
// Middleware returns a handler that returns 404 for non-embed builds
func (s *FrontendServer) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.String(http.StatusNotFound, "Frontend not embedded. Build with -tags embed to include frontend.")
c.Abort()
}
}
func ServeEmbeddedFrontend() gin.HandlerFunc {
return func(c *gin.Context) {
c.String(http.StatusNotFound, "Frontend not embedded. Build with -tags embed to include frontend.")

View File

@@ -3,11 +3,15 @@
package web
import (
"bytes"
"context"
"embed"
"encoding/json"
"io"
"io/fs"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
)
@@ -15,6 +19,162 @@ import (
//go:embed all:dist
var frontendFS embed.FS
// PublicSettingsProvider is an interface to fetch public settings
type PublicSettingsProvider interface {
GetPublicSettingsForInjection(ctx context.Context) (any, error)
}
// FrontendServer serves the embedded frontend with settings injection
type FrontendServer struct {
distFS fs.FS
fileServer http.Handler
baseHTML []byte
cache *HTMLCache
settings PublicSettingsProvider
}
// NewFrontendServer creates a new frontend server with settings injection
func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer, error) {
distFS, err := fs.Sub(frontendFS, "dist")
if err != nil {
return nil, err
}
// Read base HTML once
file, err := distFS.Open("index.html")
if err != nil {
return nil, err
}
defer func() { _ = file.Close() }()
baseHTML, err := io.ReadAll(file)
if err != nil {
return nil, err
}
cache := NewHTMLCache()
cache.SetBaseHTML(baseHTML)
return &FrontendServer{
distFS: distFS,
fileServer: http.FileServer(http.FS(distFS)),
baseHTML: baseHTML,
cache: cache,
settings: settingsProvider,
}, nil
}
// InvalidateCache invalidates the HTML cache (call when settings change)
func (s *FrontendServer) InvalidateCache() {
if s != nil && s.cache != nil {
s.cache.Invalidate()
}
}
// Middleware returns the Gin middleware handler
func (s *FrontendServer) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
path := c.Request.URL.Path
// Skip API routes
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" ||
path == "/responses" {
c.Next()
return
}
cleanPath := strings.TrimPrefix(path, "/")
if cleanPath == "" {
cleanPath = "index.html"
}
// For index.html or SPA routes, serve with injected settings
if cleanPath == "index.html" || !s.fileExists(cleanPath) {
s.serveIndexHTML(c)
return
}
// Serve static files normally
s.fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort()
}
}
func (s *FrontendServer) fileExists(path string) bool {
file, err := s.distFS.Open(path)
if err != nil {
return false
}
_ = file.Close()
return true
}
func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
// Check cache first
cached := s.cache.Get()
if cached != nil {
// Check If-None-Match for 304 response
if match := c.GetHeader("If-None-Match"); match == cached.ETag {
c.Status(http.StatusNotModified)
c.Abort()
return
}
c.Header("ETag", cached.ETag)
c.Header("Cache-Control", "no-cache") // Must revalidate
c.Data(http.StatusOK, "text/html; charset=utf-8", cached.Content)
c.Abort()
return
}
// Cache miss - fetch settings and render
ctx, cancel := context.WithTimeout(c.Request.Context(), 2*time.Second)
defer cancel()
settings, err := s.settings.GetPublicSettingsForInjection(ctx)
if err != nil {
// Fallback: serve without injection
c.Data(http.StatusOK, "text/html; charset=utf-8", s.baseHTML)
c.Abort()
return
}
settingsJSON, err := json.Marshal(settings)
if err != nil {
// Fallback: serve without injection
c.Data(http.StatusOK, "text/html; charset=utf-8", s.baseHTML)
c.Abort()
return
}
rendered := s.injectSettings(settingsJSON)
s.cache.Set(rendered, settingsJSON)
cached = s.cache.Get()
if cached != nil {
c.Header("ETag", cached.ETag)
}
c.Header("Cache-Control", "no-cache")
c.Data(http.StatusOK, "text/html; charset=utf-8", rendered)
c.Abort()
}
func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte {
// Create the script tag to inject
script := []byte(`<script>window.__APP_CONFIG__=` + string(settingsJSON) + `;</script>`)
// Inject before </head>
headClose := []byte("</head>")
return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1)
}
// ServeEmbeddedFrontend returns a middleware for serving embedded frontend
// This is the legacy function for backward compatibility when no settings provider is available
func ServeEmbeddedFrontend() gin.HandlerFunc {
distFS, err := fs.Sub(frontendFS, "dist")
if err != nil {

View File

@@ -0,0 +1,77 @@
//go:build embed
package web
import (
"crypto/sha256"
"encoding/hex"
"sync"
)
// HTMLCache manages the cached index.html with injected settings
type HTMLCache struct {
mu sync.RWMutex
cachedHTML []byte
etag string
baseHTMLHash string // Hash of the original index.html (immutable after build)
settingsVersion uint64 // Incremented when settings change
}
// CachedHTML represents the cache state
type CachedHTML struct {
Content []byte
ETag string
}
// NewHTMLCache creates a new HTML cache instance
func NewHTMLCache() *HTMLCache {
return &HTMLCache{}
}
// SetBaseHTML initializes the cache with the base HTML template
func (c *HTMLCache) SetBaseHTML(baseHTML []byte) {
c.mu.Lock()
defer c.mu.Unlock()
hash := sha256.Sum256(baseHTML)
c.baseHTMLHash = hex.EncodeToString(hash[:8]) // First 8 bytes for brevity
}
// Invalidate marks the cache as stale
func (c *HTMLCache) Invalidate() {
c.mu.Lock()
defer c.mu.Unlock()
c.settingsVersion++
c.cachedHTML = nil
c.etag = ""
}
// Get returns the cached HTML or nil if cache is stale
func (c *HTMLCache) Get() *CachedHTML {
c.mu.RLock()
defer c.mu.RUnlock()
if c.cachedHTML == nil {
return nil
}
return &CachedHTML{
Content: c.cachedHTML,
ETag: c.etag,
}
}
// Set updates the cache with new rendered HTML
func (c *HTMLCache) Set(html []byte, settingsJSON []byte) {
c.mu.Lock()
defer c.mu.Unlock()
c.cachedHTML = html
c.etag = c.generateETag(settingsJSON)
}
// generateETag creates an ETag from base HTML hash + settings hash
func (c *HTMLCache) generateETag(settingsJSON []byte) string {
settingsHash := sha256.Sum256(settingsJSON)
return `"` + c.baseHTMLHash + "-" + hex.EncodeToString(settingsHash[:8]) + `"`
}

View File

@@ -170,6 +170,30 @@ gateway:
# 允许在特定 400 错误时进行故障转移(默认:关闭)
failover_on_400: false
# =============================================================================
# API Key Auth Cache Configuration
# API Key 认证缓存配置
# =============================================================================
api_key_auth_cache:
# L1 cache size (entries), in-process LRU/TTL cache
# L1 缓存容量(条目数),进程内 LRU/TTL 缓存
l1_size: 65535
# L1 cache TTL (seconds)
# L1 缓存 TTL
l1_ttl_seconds: 15
# L2 cache TTL (seconds), stored in Redis
# L2 缓存 TTLRedis 中存储
l2_ttl_seconds: 300
# Negative cache TTL (seconds)
# 负缓存 TTL
negative_ttl_seconds: 30
# TTL jitter percent (0-100)
# TTL 抖动百分比0-100
jitter_percent: 10
# Enable singleflight for cache misses
# 缓存未命中时启用 singleflight 合并回源
singleflight: true
# =============================================================================
# Concurrency Wait Configuration
# 并发等待配置

View File

@@ -170,6 +170,30 @@ gateway:
# 允许在特定 400 错误时进行故障转移(默认:关闭)
failover_on_400: false
# =============================================================================
# API Key Auth Cache Configuration
# API Key 认证缓存配置
# =============================================================================
api_key_auth_cache:
# L1 cache size (entries), in-process LRU/TTL cache
# L1 缓存容量(条目数),进程内 LRU/TTL 缓存
l1_size: 65535
# L1 cache TTL (seconds)
# L1 缓存 TTL
l1_ttl_seconds: 15
# L2 cache TTL (seconds), stored in Redis
# L2 缓存 TTLRedis 中存储
l2_ttl_seconds: 300
# Negative cache TTL (seconds)
# 负缓存 TTL
negative_ttl_seconds: 30
# TTL jitter percent (0-100)
# TTL 抖动百分比0-100
jitter_percent: 10
# Enable singleflight for cache misses
# 缓存未命中时启用 singleflight 合并回源
singleflight: true
# =============================================================================
# Concurrency Wait Configuration
# 并发等待配置

View File

@@ -22,6 +22,7 @@ export interface SystemSettings {
api_base_url: string
contact_info: string
doc_url: string
home_content: string
// SMTP settings
smtp_host: string
smtp_port: number
@@ -55,6 +56,7 @@ export interface UpdateSettingsRequest {
api_base_url?: string
contact_info?: string
doc_url?: string
home_content?: string
smtp_host?: string
smtp_port?: number
smtp_username?: string

View File

@@ -1900,7 +1900,11 @@ export default {
logoHint: 'PNG, JPG, or SVG. Max 300KB. Recommended: 80x80px square image.',
logoSizeError: 'Image size exceeds 300KB limit ({size}KB)',
logoTypeError: 'Please select an image file',
logoReadError: 'Failed to read the image file'
logoReadError: 'Failed to read the image file',
homeContent: 'Home Page Content',
homeContentPlaceholder: 'Enter custom content for the home page. Supports Markdown & HTML. If a URL is entered, it will be displayed as an iframe.',
homeContentHint: 'Customize the home page content. Supports Markdown/HTML. If you enter a URL (starting with http:// or https://), it will be used as an iframe src to embed an external page. When set, the default status information will no longer be displayed.',
homeContentIframeWarning: '⚠️ iframe mode note: Some websites have X-Frame-Options or CSP security policies that prevent embedding in iframes. If the page appears blank or shows an error, please verify the target website allows embedding, or consider using HTML mode to build your own content.'
},
smtp: {
title: 'SMTP Settings',

View File

@@ -2043,7 +2043,11 @@ export default {
logoHint: 'PNG、JPG 或 SVG 格式,最大 300KB。建议80x80px 正方形图片。',
logoSizeError: '图片大小超过 300KB 限制({size}KB',
logoTypeError: '请选择图片文件',
logoReadError: '读取图片文件失败'
logoReadError: '读取图片文件失败',
homeContent: '首页内容',
homeContentPlaceholder: '在此输入首页内容,支持 Markdown & HTML 代码。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性。',
homeContentHint: '自定义首页内容,支持 Markdown/HTML。如果输入的是链接以 http:// 或 https:// 开头),则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为首页。设置后首页的状态信息将不再显示。',
homeContentIframeWarning: '⚠️ iframe 模式提示:部分网站设置了 X-Frame-Options 或 CSP 安全策略,禁止被嵌入到 iframe 中。如果页面显示空白或报错,请确认目标网站允许被嵌入,或考虑使用 HTML 模式自行构建页面内容。'
},
smtp: {
title: 'SMTP 设置',

View File

@@ -6,7 +6,20 @@ import i18n from './i18n'
import './style.css'
const app = createApp(App)
app.use(createPinia())
const pinia = createPinia()
app.use(pinia)
// Initialize settings from injected config BEFORE mounting (prevents flash)
// This must happen after pinia is installed but before router and i18n
import { useAppStore } from '@/stores/app'
const appStore = useAppStore()
appStore.initFromInjectedConfig()
// Set document title immediately after config is loaded
if (appStore.siteName && appStore.siteName !== 'Sub2API') {
document.title = `${appStore.siteName} - AI API Gateway`
}
app.use(router)
app.use(i18n)

View File

@@ -5,6 +5,7 @@
import { createRouter, createWebHistory, type RouteRecordRaw } from 'vue-router'
import { useAuthStore } from '@/stores/auth'
import { useAppStore } from '@/stores/app'
/**
* Route definitions with lazy loading
@@ -323,10 +324,12 @@ router.beforeEach((to, _from, next) => {
}
// Set page title
const appStore = useAppStore()
const siteName = appStore.siteName || 'Sub2API'
if (to.meta.title) {
document.title = `${to.meta.title} - Sub2API`
document.title = `${to.meta.title} - ${siteName}`
} else {
document.title = 'Sub2API'
document.title = siteName
}
// Check if route requires authentication

View File

@@ -279,11 +279,31 @@ export const useAppStore = defineStore('app', () => {
// ==================== Public Settings Management ====================
/**
* Apply settings to store state (internal helper to avoid code duplication)
*/
function applySettings(config: PublicSettings): void {
cachedPublicSettings.value = config
siteName.value = config.site_name || 'Sub2API'
siteLogo.value = config.site_logo || ''
siteVersion.value = config.version || ''
contactInfo.value = config.contact_info || ''
apiBaseUrl.value = config.api_base_url || ''
docUrl.value = config.doc_url || ''
publicSettingsLoaded.value = true
}
/**
* Fetch public settings (uses cache unless force=true)
* @param force - Force refresh from API
*/
async function fetchPublicSettings(force = false): Promise<PublicSettings | null> {
// Check for injected config from server (eliminates flash)
if (!publicSettingsLoaded.value && !force && window.__APP_CONFIG__) {
applySettings(window.__APP_CONFIG__)
return window.__APP_CONFIG__
}
// Return cached data if available and not forcing refresh
if (publicSettingsLoaded.value && !force) {
if (cachedPublicSettings.value) {
@@ -300,6 +320,7 @@ export const useAppStore = defineStore('app', () => {
api_base_url: apiBaseUrl.value,
contact_info: contactInfo.value,
doc_url: docUrl.value,
home_content: '',
linuxdo_oauth_enabled: false,
version: siteVersion.value
}
@@ -313,14 +334,7 @@ export const useAppStore = defineStore('app', () => {
publicSettingsLoading.value = true
try {
const data = await fetchPublicSettingsAPI()
cachedPublicSettings.value = data
siteName.value = data.site_name || 'Sub2API'
siteLogo.value = data.site_logo || ''
siteVersion.value = data.version || ''
contactInfo.value = data.contact_info || ''
apiBaseUrl.value = data.api_base_url || ''
docUrl.value = data.doc_url || ''
publicSettingsLoaded.value = true
applySettings(data)
return data
} catch (error) {
console.error('Failed to fetch public settings:', error)
@@ -338,6 +352,19 @@ export const useAppStore = defineStore('app', () => {
cachedPublicSettings.value = null
}
/**
* Initialize settings from injected config (window.__APP_CONFIG__)
* This is called synchronously before Vue app mounts to prevent flash
* @returns true if config was found and applied, false otherwise
*/
function initFromInjectedConfig(): boolean {
if (window.__APP_CONFIG__) {
applySettings(window.__APP_CONFIG__)
return true
}
return false
}
// ==================== Return Store API ====================
return {
@@ -355,6 +382,7 @@ export const useAppStore = defineStore('app', () => {
contactInfo,
apiBaseUrl,
docUrl,
cachedPublicSettings,
// Version state
versionLoaded,
@@ -391,6 +419,7 @@ export const useAppStore = defineStore('app', () => {
// Public settings actions
fetchPublicSettings,
clearPublicSettingsCache
clearPublicSettingsCache,
initFromInjectedConfig
}
})

9
frontend/src/types/global.d.ts vendored Normal file
View File

@@ -0,0 +1,9 @@
import type { PublicSettings } from '@/types'
declare global {
interface Window {
__APP_CONFIG__?: PublicSettings
}
}
export {}

View File

@@ -74,6 +74,7 @@ export interface PublicSettings {
api_base_url: string
contact_info: string
doc_url: string
home_content: string
linuxdo_oauth_enabled: boolean
version: string
}

View File

@@ -1,6 +1,21 @@
<template>
<!-- Custom Home Content: Full Page Mode -->
<div v-if="homeContent" class="min-h-screen">
<!-- iframe mode -->
<iframe
v-if="isHomeContentUrl"
:src="homeContent.trim()"
class="h-screen w-full border-0"
allowfullscreen
></iframe>
<!-- HTML mode - SECURITY: homeContent is admin-only setting, XSS risk is acceptable -->
<div v-else v-html="homeContent"></div>
</div>
<!-- Default Home Page -->
<div
class="relative min-h-screen overflow-hidden bg-gradient-to-br from-gray-50 via-primary-50/30 to-gray-100 dark:from-dark-950 dark:via-dark-900 dark:to-dark-950"
v-else
class="relative flex min-h-screen flex-col overflow-hidden bg-gradient-to-br from-gray-50 via-primary-50/30 to-gray-100 dark:from-dark-950 dark:via-dark-900 dark:to-dark-950"
>
<!-- Background Decorations -->
<div class="pointer-events-none absolute inset-0 overflow-hidden">
@@ -96,7 +111,7 @@
</header>
<!-- Main Content -->
<main class="relative z-10 px-6 py-16">
<main class="relative z-10 flex-1 px-6 py-16">
<div class="mx-auto max-w-6xl">
<!-- Hero Section - Left/Right Layout -->
<div class="mb-12 flex flex-col items-center justify-between gap-12 lg:flex-row lg:gap-16">
@@ -392,21 +407,27 @@
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { getPublicSettings } from '@/api/auth'
import { useAuthStore } from '@/stores'
import { useAuthStore, useAppStore } from '@/stores'
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue'
import Icon from '@/components/icons/Icon.vue'
import { sanitizeUrl } from '@/utils/url'
const { t } = useI18n()
const authStore = useAuthStore()
const appStore = useAppStore()
// Site settings
const siteName = ref('Sub2API')
const siteLogo = ref('')
const siteSubtitle = ref('AI API Gateway Platform')
const docUrl = ref('')
// Site settings - directly from appStore (already initialized from injected config)
const siteName = computed(() => appStore.cachedPublicSettings?.site_name || appStore.siteName || 'Sub2API')
const siteLogo = computed(() => appStore.cachedPublicSettings?.site_logo || appStore.siteLogo || '')
const siteSubtitle = computed(() => appStore.cachedPublicSettings?.site_subtitle || 'AI API Gateway Platform')
const docUrl = computed(() => appStore.cachedPublicSettings?.doc_url || appStore.docUrl || '')
const homeContent = computed(() => appStore.cachedPublicSettings?.home_content || '')
// Check if homeContent is a URL (for iframe display)
const isHomeContentUrl = computed(() => {
const content = homeContent.value.trim()
return content.startsWith('http://') || content.startsWith('https://')
})
// Theme
const isDark = ref(document.documentElement.classList.contains('dark'))
@@ -446,20 +467,15 @@ function initTheme() {
}
}
onMounted(async () => {
onMounted(() => {
initTheme()
// Check auth state
authStore.checkAuth()
try {
const settings = await getPublicSettings()
siteName.value = settings.site_name || 'Sub2API'
siteLogo.value = sanitizeUrl(settings.site_logo || '', { allowRelative: true })
siteSubtitle.value = settings.site_subtitle || 'AI API Gateway Platform'
docUrl.value = sanitizeUrl(settings.doc_url || '', { allowRelative: true })
} catch (error) {
console.error('Failed to load public settings:', error)
// Ensure public settings are loaded (will use cache if already loaded from injected config)
if (!appStore.publicSettingsLoaded) {
appStore.fetchPublicSettings()
}
})
</script>

View File

@@ -562,6 +562,26 @@
</div>
</div>
</div>
<!-- Home Content -->
<div>
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.settings.site.homeContent') }}
</label>
<textarea
v-model="form.home_content"
rows="6"
class="input font-mono text-sm"
:placeholder="t('admin.settings.site.homeContentPlaceholder')"
></textarea>
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.settings.site.homeContentHint') }}
</p>
<!-- iframe CSP Warning -->
<p class="mt-2 text-xs text-amber-600 dark:text-amber-400">
{{ t('admin.settings.site.homeContentIframeWarning') }}
</p>
</div>
</div>
</div>
@@ -837,6 +857,7 @@ const form = reactive<SettingsForm>({
api_base_url: '',
contact_info: '',
doc_url: '',
home_content: '',
smtp_host: '',
smtp_port: 587,
smtp_username: '',
@@ -945,6 +966,7 @@ async function saveSettings() {
api_base_url: form.api_base_url,
contact_info: form.contact_info,
doc_url: form.doc_url,
home_content: form.home_content,
smtp_host: form.smtp_host,
smtp_port: form.smtp_port,
smtp_username: form.smtp_username,

View File

@@ -1,8 +1,39 @@
import { defineConfig } from 'vite'
import { defineConfig, Plugin } from 'vite'
import vue from '@vitejs/plugin-vue'
import checker from 'vite-plugin-checker'
import { resolve } from 'path'
/**
* Vite 插件:开发模式下注入公开配置到 index.html
* 与生产模式的后端注入行为保持一致,消除闪烁
*/
function injectPublicSettings(): Plugin {
const backendUrl = process.env.VITE_DEV_PROXY_TARGET || 'http://localhost:8080'
return {
name: 'inject-public-settings',
transformIndexHtml: {
order: 'pre',
async handler(html) {
try {
const response = await fetch(`${backendUrl}/api/v1/settings/public`, {
signal: AbortSignal.timeout(2000)
})
if (response.ok) {
const data = await response.json()
if (data.code === 0 && data.data) {
const script = `<script>window.__APP_CONFIG__=${JSON.stringify(data.data)};</script>`
return html.replace('</head>', `${script}\n</head>`)
}
}
} catch (e) {
console.warn('[vite] 无法获取公开配置,将回退到 API 调用:', (e as Error).message)
}
return html
}
}
}
}
export default defineConfig({
plugins: [
@@ -10,7 +41,8 @@ export default defineConfig({
checker({
typescript: true,
vueTsc: true
})
}),
injectPublicSettings()
],
resolve: {
alias: {