Merge PR #142: feat: 多平台网关优化与用户系统增强

主要功能:
- Gemini OAuth 配额系统优化:Google One tier 自动推断
- Antigravity 网关增强:Thinking Block 重试、Claude 模型 signature 透传
- 账号调度改进:临时不可调度功能、负载感知调度优化
- 前端用户体验:账号管理界面优化、使用教程改进

冲突解决:保留 handleUpstreamError 的 prefix 参数和日志记录能力,
同时合并 PR 的 thinking block 重试和 model fallback 功能。
This commit is contained in:
shaw
2026-01-04 20:30:19 +08:00
183 changed files with 8223 additions and 3839 deletions

View File

@@ -1,3 +1,4 @@
// Package config provides configuration loading, defaults, and validation.
package config
import (
@@ -139,7 +140,7 @@ type GatewayConfig struct {
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
@@ -241,7 +242,7 @@ type DefaultConfig struct {
AdminPassword string `mapstructure:"admin_password"`
UserConcurrency int `mapstructure:"user_concurrency"`
UserBalance float64 `mapstructure:"user_balance"`
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
APIKeyPrefix string `mapstructure:"api_key_prefix"`
RateMultiplier float64 `mapstructure:"rate_multiplier"`
}

View File

@@ -1,9 +1,12 @@
// Package admin provides HTTP handlers for administrative operations.
package admin
import (
"errors"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
@@ -69,42 +72,45 @@ func NewAccountHandler(
// CreateAccountRequest represents create account request
type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// UpdateAccountRequest represents update account request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct {
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// AccountWithConcurrency extends Account with real-time concurrency info
@@ -179,18 +185,40 @@ func (h *AccountHandler) Create(c *gin.Context) {
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
response.ErrorFrom(c, err)
return
}
@@ -213,18 +241,40 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Name: req.Name,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
Status: req.Status,
GroupIDs: req.GroupIDs,
Name: req.Name,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
Status: req.Status,
GroupIDs: req.GroupIDs,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
response.ErrorFrom(c, err)
return
}
@@ -568,6 +618,9 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
hasUpdates := req.Name != "" ||
req.ProxyID != nil ||
req.Concurrency != nil ||
@@ -583,15 +636,16 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
}
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -781,6 +835,49 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
}
// GetTempUnschedulable handles getting temporary unschedulable status
// GET /api/v1/admin/accounts/:id/temp-unschedulable
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
state, err := h.rateLimitService.GetTempUnschedStatus(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if state == nil || state.UntilUnix <= time.Now().Unix() {
response.Success(c, gin.H{"active": false})
return
}
response.Success(c, gin.H{
"active": true,
"state": state,
})
}
// ClearTempUnschedulable handles clearing temporary unschedulable status
// DELETE /api/v1/admin/accounts/:id/temp-unschedulable
func (h *AccountHandler) ClearTempUnschedulable(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
if err := h.rateLimitService.ClearTempUnschedulable(c.Request.Context(), accountID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Temp unschedulable cleared successfully"})
}
// GetTodayStats handles getting account today statistics
// GET /api/v1/admin/accounts/:id/today-stats
func (h *AccountHandler) GetTodayStats(c *gin.Context) {

View File

@@ -75,8 +75,8 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
"active_users": stats.ActiveUsers,
// API Key 统计
"total_api_keys": stats.TotalApiKeys,
"active_api_keys": stats.ActiveApiKeys,
"total_api_keys": stats.TotalAPIKeys,
"active_api_keys": stats.ActiveAPIKeys,
// 账户统计
"total_accounts": stats.TotalAccounts,
@@ -193,10 +193,10 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
})
}
// GetApiKeyUsageTrend handles getting API key usage trend data
// GetAPIKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
limitStr := c.DefaultQuery("limit", "5")
@@ -205,7 +205,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
limit = 5
}
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
@@ -273,26 +273,26 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
response.Success(c, gin.H{"stats": stats})
}
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
type BatchAPIKeysUsageRequest struct {
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
// GetBatchAPIKeysUsage handles getting usage stats for multiple API keys
// POST /api/v1/admin/dashboard/api-keys-usage
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
var req BatchApiKeysUsageRequest
func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
var req BatchAPIKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.ApiKeyIDs) == 0 {
if len(req.APIKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -18,6 +18,7 @@ func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *Gemi
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
}
// GetCapabilities returns the Gemini OAuth configuration capabilities.
// GET /api/v1/admin/gemini/oauth/capabilities
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
cfg := h.geminiOAuthService.GetOAuthConfig()
@@ -30,6 +31,8 @@ type GeminiGenerateAuthURLRequest struct {
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
// 默认为 "code_assist" 以保持向后兼容
OAuthType string `json:"oauth_type"`
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
TierID string `json:"tier_id"`
}
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
@@ -54,7 +57,7 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
// Always pass the "hosted" callback URI; the OAuth service may override it depending on
// oauth_type and whether the built-in Gemini CLI OAuth client is used.
redirectURI := deriveGeminiRedirectURI(c)
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType)
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType, req.TierID)
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
@@ -76,6 +79,9 @@ type GeminiExchangeCodeRequest struct {
ProxyID *int64 `json:"proxy_id"`
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
OAuthType string `json:"oauth_type"`
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
// This field is optional; when omitted, the server uses the tier stored in the OAuth session.
TierID string `json:"tier_id"`
}
// ExchangeCode exchanges authorization code for tokens.
@@ -103,6 +109,7 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
Code: req.Code,
ProxyID: req.ProxyID,
OAuthType: oauthType,
TierID: req.TierID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())

View File

@@ -237,9 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
return
}
outKeys := make([]dto.ApiKey, 0, len(keys))
outKeys := make([]dto.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
outKeys = append(outKeys, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
}

View File

@@ -34,26 +34,31 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SmtpHost: settings.SmtpHost,
SmtpPort: settings.SmtpPort,
SmtpUsername: settings.SmtpUsername,
SmtpPassword: settings.SmtpPassword,
SmtpFrom: settings.SmtpFrom,
SmtpFromName: settings.SmtpFromName,
SmtpUseTLS: settings.SmtpUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPassword: settings.SMTPPassword,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
})
}
@@ -64,13 +69,13 @@ type UpdateSettingsRequest struct {
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
@@ -81,13 +86,20 @@ type UpdateSettingsRequest struct {
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DocURL string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
}
// UpdateSettings 更新系统设置
@@ -106,8 +118,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
// Turnstile 参数验证
@@ -141,26 +153,31 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SmtpHost: req.SmtpHost,
SmtpPort: req.SmtpPort,
SmtpUsername: req.SmtpUsername,
SmtpPassword: req.SmtpPassword,
SmtpFrom: req.SmtpFrom,
SmtpFromName: req.SmtpFromName,
SmtpUseTLS: req.SmtpUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
ApiBaseUrl: req.ApiBaseUrl,
ContactInfo: req.ContactInfo,
DocUrl: req.DocUrl,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@@ -176,69 +193,74 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SmtpHost: updatedSettings.SmtpHost,
SmtpPort: updatedSettings.SmtpPort,
SmtpUsername: updatedSettings.SmtpUsername,
SmtpPassword: updatedSettings.SmtpPassword,
SmtpFrom: updatedSettings.SmtpFrom,
SmtpFromName: updatedSettings.SmtpFromName,
SmtpUseTLS: updatedSettings.SmtpUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
ApiBaseUrl: updatedSettings.ApiBaseUrl,
ContactInfo: updatedSettings.ContactInfo,
DocUrl: updatedSettings.DocUrl,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPassword: updatedSettings.SMTPPassword,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
})
}
// TestSmtpRequest 测试SMTP连接请求
type TestSmtpRequest struct {
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
SMTPUseTLS bool `json:"smtp_use_tls"`
}
// TestSmtpConnection 测试SMTP连接
// TestSMTPConnection 测试SMTP连接
// POST /api/v1/admin/settings/test-smtp
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
var req TestSmtpRequest
func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
var req TestSMTPRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SmtpPassword
password := req.SMTPPassword
if password == "" {
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
config := &service.SMTPConfig{
Host: req.SMTPHost,
Port: req.SMTPPort,
Username: req.SMTPUsername,
Password: password,
UseTLS: req.SmtpUseTLS,
UseTLS: req.SMTPUseTLS,
}
err := h.emailService.TestSmtpConnectionWithConfig(config)
err := h.emailService.TestSMTPConnectionWithConfig(config)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -250,13 +272,13 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
// SendTestEmailRequest 发送测试邮件请求
type SendTestEmailRequest struct {
Email string `json:"email" binding:"required,email"`
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
SMTPHost string `json:"smtp_host" binding:"required"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
}
// SendTestEmail 发送测试邮件
@@ -268,27 +290,27 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
return
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SmtpPassword
password := req.SMTPPassword
if password == "" {
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
config := &service.SMTPConfig{
Host: req.SMTPHost,
Port: req.SMTPPort,
Username: req.SMTPUsername,
Password: password,
From: req.SmtpFrom,
FromName: req.SmtpFromName,
UseTLS: req.SmtpUseTLS,
From: req.SMTPFrom,
FromName: req.SMTPFromName,
UseTLS: req.SMTPUseTLS,
}
siteName := h.settingService.GetSiteName(c.Request.Context())
@@ -333,10 +355,10 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
response.Success(c, gin.H{"message": "Test email sent successfully"})
}
// GetAdminApiKey 获取管理员 API Key 状态
// GetAdminAPIKey 获取管理员 API Key 状态
// GET /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
func (h *SettingHandler) GetAdminAPIKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminAPIKeyStatus(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
@@ -348,10 +370,10 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
})
}
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
// RegenerateAdminAPIKey 生成/重新生成管理员 API Key
// POST /api/v1/admin/settings/admin-api-key/regenerate
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
func (h *SettingHandler) RegenerateAdminAPIKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminAPIKey(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
@@ -362,10 +384,10 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
})
}
// DeleteAdminApiKey 删除管理员 API Key
// DeleteAdminAPIKey 删除管理员 API Key
// DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
if err := h.settingService.DeleteAdminAPIKey(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}

View File

@@ -17,14 +17,14 @@ import (
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.ApiKeyService
apiKeyService *service.APIKeyService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
func NewUsageHandler(
usageService *service.UsageService,
apiKeyService *service.ApiKeyService,
apiKeyService *service.APIKeyService,
adminService service.AdminService,
) *UsageHandler {
return &UsageHandler{
@@ -125,7 +125,7 @@ func (h *UsageHandler) List(c *gin.Context) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
@@ -207,7 +207,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -269,9 +269,9 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
response.Success(c, result)
}
// SearchApiKeys handles searching API keys by user
// SearchAPIKeys handles searching API keys by user
// GET /api/v1/admin/usage/search-api-keys
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
userIDStr := c.Query("user_id")
keyword := c.Query("q")
@@ -285,22 +285,22 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userID = id
}
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return simplified API key list (only id and name)
type SimpleApiKey struct {
type SimpleAPIKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
UserID int64 `json:"user_id"`
}
result := make([]SimpleApiKey, len(keys))
result := make([]SimpleAPIKey, len(keys))
for i, k := range keys {
result[i] = SimpleApiKey{
result[i] = SimpleAPIKey{
ID: k.ID,
Name: k.Name,
UserID: k.UserID,

View File

@@ -243,9 +243,9 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
return
}
out := make([]dto.ApiKey, 0, len(keys))
out := make([]dto.APIKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
out = append(out, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, out, total, page, pageSize)
}

View File

@@ -1,3 +1,4 @@
// Package handler provides HTTP request handlers for the application.
package handler
import (
@@ -14,11 +15,11 @@ import (
// APIKeyHandler handles API key-related requests
type APIKeyHandler struct {
apiKeyService *service.ApiKeyService
apiKeyService *service.APIKeyService
}
// NewAPIKeyHandler creates a new APIKeyHandler
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
return &APIKeyHandler{
apiKeyService: apiKeyService,
}
@@ -56,9 +57,9 @@ func (h *APIKeyHandler) List(c *gin.Context) {
return
}
out := make([]dto.ApiKey, 0, len(keys))
out := make([]dto.APIKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
out = append(out, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
@@ -90,7 +91,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.ApiKeyFromService(key))
response.Success(c, dto.APIKeyFromService(key))
}
// Create handles creating a new API key
@@ -108,7 +109,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
return
}
svcReq := service.CreateApiKeyRequest{
svcReq := service.CreateAPIKeyRequest{
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
@@ -119,7 +120,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.ApiKeyFromService(key))
response.Success(c, dto.APIKeyFromService(key))
}
// Update handles updating an API key
@@ -143,7 +144,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
svcReq := service.UpdateApiKeyRequest{}
svcReq := service.UpdateAPIKeyRequest{}
if req.Name != "" {
svcReq.Name = &req.Name
}
@@ -158,7 +159,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.ApiKeyFromService(key))
response.Success(c, dto.APIKeyFromService(key))
}
// Delete handles deleting an API key

View File

@@ -1,3 +1,4 @@
// Package dto provides data transfer objects for HTTP handlers.
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
@@ -26,11 +27,11 @@ func UserFromService(u *service.User) *User {
return nil
}
out := UserFromServiceShallow(u)
if len(u.ApiKeys) > 0 {
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
for i := range u.ApiKeys {
k := u.ApiKeys[i]
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
if len(u.APIKeys) > 0 {
out.APIKeys = make([]APIKey, 0, len(u.APIKeys))
for i := range u.APIKeys {
k := u.APIKeys[i]
out.APIKeys = append(out.APIKeys, *APIKeyFromService(&k))
}
}
if len(u.Subscriptions) > 0 {
@@ -43,11 +44,11 @@ func UserFromService(u *service.User) *User {
return out
}
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
func APIKeyFromService(k *service.APIKey) *APIKey {
if k == nil {
return nil
}
return &ApiKey{
return &APIKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
@@ -103,28 +104,30 @@ func AccountFromServiceShallow(a *service.Account) *Account {
return nil
}
return &Account{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
TempUnschedulableUntil: a.TempUnschedulableUntil,
TempUnschedulableReason: a.TempUnschedulableReason,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
}
@@ -220,7 +223,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
return &UsageLog{
ID: l.ID,
UserID: l.UserID,
ApiKeyID: l.ApiKeyID,
APIKeyID: l.APIKeyID,
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
@@ -245,7 +248,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
FirstTokenMs: l.FirstTokenMs,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
ApiKey: ApiKeyFromService(l.ApiKey),
APIKey: APIKeyFromService(l.APIKey),
Account: AccountFromService(l.Account),
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),

View File

@@ -5,13 +5,13 @@ type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password,omitempty"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
@@ -20,12 +20,19 @@ type SystemSettings struct {
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DocURL string `json:"doc_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
}
type PublicSettings struct {
@@ -36,8 +43,8 @@ type PublicSettings struct {
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DocURL string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -15,11 +15,11 @@ type User struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ApiKeys []ApiKey `json:"api_keys,omitempty"`
APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
type ApiKey struct {
type APIKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
@@ -76,6 +76,9 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
@@ -136,7 +139,7 @@ type RedeemCode struct {
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
@@ -168,7 +171,7 @@ type UsageLog struct {
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
ApiKey *ApiKey `json:"api_key,omitempty"`
APIKey *APIKey `json:"api_key,omitempty"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`

View File

@@ -53,7 +53,7 @@ func NewGatewayHandler(
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware2.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -259,7 +259,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
@@ -383,7 +383,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
@@ -400,7 +400,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Returns models based on account configurations (model_mapping whitelist)
// Falls back to default models if no whitelist is configured
func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware2.GetApiKeyFromContext(c)
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
var groupID *int64
var platform string
@@ -458,7 +458,7 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
// Usage handles getting account balance for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -628,7 +628,7 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware2.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return

View File

@@ -21,7 +21,7 @@ import (
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware.GetAPIKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
@@ -67,7 +67,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
// GeminiV1BetaGetModel proxies:
// GET /v1beta/models/{model}
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware.GetAPIKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
@@ -120,7 +120,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
// POST /v1beta/models/{model}:generateContent
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware.GetAPIKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
@@ -299,7 +299,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,

View File

@@ -41,7 +41,7 @@ func NewOpenAIGatewayHandler(
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -235,7 +235,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,

View File

@@ -39,9 +39,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DocURL: settings.DocURL,
Version: h.version,
})
}

View File

@@ -18,11 +18,11 @@ import (
// UsageHandler handles usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.ApiKeyService
apiKeyService *service.APIKeyService
}
// NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.APIKeyService) *UsageHandler {
return &UsageHandler{
usageService: usageService,
apiKeyService: apiKeyService,
@@ -111,7 +111,7 @@ func (h *UsageHandler) List(c *gin.Context) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: subject.UserID, // Always filter by current user for security
ApiKeyID: apiKeyID,
APIKeyID: apiKeyID,
Model: model,
Stream: stream,
BillingType: billingType,
@@ -235,7 +235,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
var stats *service.UsageStats
var err error
if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
stats, err = h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
}
@@ -346,49 +346,49 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
})
}
// BatchApiKeysUsageRequest represents the request for batch API keys usage
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
// BatchAPIKeysUsageRequest represents the request for batch API keys usage
type BatchAPIKeysUsageRequest struct {
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
// DashboardAPIKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req BatchApiKeysUsageRequest
var req BatchAPIKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.ApiKeyIDs) == 0 {
if len(req.APIKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
// Limit the number of API key IDs to prevent SQL parameter overflow
if len(req.ApiKeyIDs) > 100 {
if len(req.APIKeyIDs) > 100 {
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
return
}
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
validAPIKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.APIKeyIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
if len(validApiKeyIDs) == 0 {
if len(validAPIKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -1,3 +1,4 @@
// Package antigravity provides a client for the Antigravity API.
package antigravity
import (
@@ -57,6 +58,29 @@ type TierInfo struct {
Description string `json:"description"` // 描述
}
// UnmarshalJSON supports both legacy string tiers and object tiers.
func (t *TierInfo) UnmarshalJSON(data []byte) error {
data = bytes.TrimSpace(data)
if len(data) == 0 || string(data) == "null" {
return nil
}
if data[0] == '"' {
var id string
if err := json.Unmarshal(data, &id); err != nil {
return err
}
t.ID = id
return nil
}
type alias TierInfo
var decoded alias
if err := json.Unmarshal(data, &decoded); err != nil {
return err
}
*t = TierInfo(decoded)
return nil
}
// IneligibleTier 不符合条件的层级信息
type IneligibleTier struct {
Tier *TierInfo `json:"tier,omitempty"`

View File

@@ -240,10 +240,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID,
},
}
// 只有 Gemini 模型使用 dummy signature
// Claude 模型不设置 signature避免验证问题
// tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature跳过 thought_signature 校验
// - Claude 模型:透传上游返回的真实 signatureVertex/Google 需要完整签名链路)
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
part.ThoughtSignature = block.Signature
}
parts = append(parts, part)

View File

@@ -15,26 +15,26 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
description string
}{
{
name: "Claude model - skip thinking block without signature",
name: "Claude model - drop thinking without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
expectedParts: 2, // thinking 内容被丢弃
description: "Claude模型应丢弃无signature的thinking block内容",
},
{
name: "Claude model - keep thinking block with signature",
name: "Claude model - preserve thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "thinking", "thinking": "Let me think...", "signature": "sig_real_123"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signaturethinking block",
expectedParts: 3,
description: "Claude模型应透传带 signaturethinking block(用于 Vertex 签名链路)",
},
{
name: "Gemini model - use dummy signature",
@@ -61,10 +61,64 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
switch tt.name {
case "Claude model - preserve thinking block with signature":
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts))
}
if !parts[1].Thought || parts[1].ThoughtSignature != "sig_real_123" {
t.Fatalf("expected thought part with signature sig_real_123, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature)
}
case "Gemini model - use dummy signature":
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts))
}
if !parts[1].Thought || parts[1].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy thought signature, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature)
}
}
})
}
}
func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
content := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
}
})
t.Run("Claude model - preserve valid signature for tool_use", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(content), toolIDToName, false)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
// Claude 模型应透传有效的 signatureVertex/Google 需要完整签名链路)
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {

View File

@@ -1,3 +1,4 @@
// Package claude provides constants and helpers for Claude API integration.
package claude
// Claude Code 客户端相关常量
@@ -16,13 +17,13 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头
// DefaultHeaders 是 Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",

View File

@@ -1,3 +1,4 @@
// Package errors provides application error types and helpers.
// nolint:mnd
package errors

View File

@@ -1,7 +1,6 @@
package gemini
// This package provides minimal fallback model metadata for Gemini native endpoints.
// Package gemini provides minimal fallback model metadata for Gemini native endpoints.
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
package gemini
type Model struct {
Name string `json:"name"`

View File

@@ -1,5 +1,10 @@
package geminicli
import (
"bytes"
"encoding/json"
)
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
type LoadCodeAssistRequest struct {
Metadata LoadCodeAssistMetadata `json:"metadata"`
@@ -11,12 +16,51 @@ type LoadCodeAssistMetadata struct {
PluginType string `json:"pluginType"`
}
type TierInfo struct {
ID string `json:"id"`
}
// UnmarshalJSON supports both legacy string tiers and object tiers.
func (t *TierInfo) UnmarshalJSON(data []byte) error {
data = bytes.TrimSpace(data)
if len(data) == 0 || string(data) == "null" {
return nil
}
if data[0] == '"' {
var id string
if err := json.Unmarshal(data, &id); err != nil {
return err
}
t.ID = id
return nil
}
type alias TierInfo
var decoded alias
if err := json.Unmarshal(data, &decoded); err != nil {
return err
}
*t = TierInfo(decoded)
return nil
}
type LoadCodeAssistResponse struct {
CurrentTier string `json:"currentTier,omitempty"`
CurrentTier *TierInfo `json:"currentTier,omitempty"`
PaidTier *TierInfo `json:"paidTier,omitempty"`
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
}
// GetTier extracts tier ID, prioritizing paidTier over currentTier
func (r *LoadCodeAssistResponse) GetTier() string {
if r.PaidTier != nil && r.PaidTier.ID != "" {
return r.PaidTier.ID
}
if r.CurrentTier != nil {
return r.CurrentTier.ID
}
return ""
}
type AllowedTier struct {
ID string `json:"id"`
IsDefault bool `json:"isDefault,omitempty"`

View File

@@ -1,3 +1,4 @@
// Package geminicli provides helpers for interacting with Gemini CLI tools.
package geminicli
import "time"
@@ -26,6 +27,12 @@ const (
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// DefaultScopes for Google One (personal Google accounts with Gemini access)
// Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client,
// Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client
// cannot request restricted scopes like generative-language.retriever or drive.readonly.
DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"

View File

@@ -11,11 +11,12 @@ type Model struct {
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.
const DefaultTestModel = "gemini-3-pro-preview"
const DefaultTestModel = "gemini-2.0-flash"

View File

@@ -19,13 +19,17 @@ type OAuthConfig struct {
}
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
CreatedAt time.Time `json:"created_at"`
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
ProjectID string `json:"project_id,omitempty"`
// TierID is a user-selected fallback tier.
// For oauth types that support auto detection (google_one/code_assist), the server will prefer
// the detected tier and fall back to TierID when detection fails.
TierID string `json:"tier_id,omitempty"`
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
CreatedAt time.Time `json:"created_at"`
}
type SessionStore struct {
@@ -172,23 +176,32 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
if effective.Scopes == "" {
// Use different default scopes based on OAuth type
if oauthType == "ai_studio" {
switch oauthType {
case "ai_studio":
// Built-in client can't request some AI Studio scopes (notably generative-language).
if isBuiltinClient {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultAIStudioScopes
}
} else {
case "google_one":
// Google One uses built-in Gemini CLI client (same as code_assist)
// Built-in client can't request restricted scopes like generative-language.retriever
if isBuiltinClient {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultGoogleOneScopes
}
default:
// Default to Code Assist scopes
effective.Scopes = DefaultCodeAssistScopes
}
} else if oauthType == "ai_studio" && isBuiltinClient {
} else if (oauthType == "ai_studio" || oauthType == "google_one") && isBuiltinClient {
// If user overrides scopes while still using the built-in client, strip restricted scopes.
parts := strings.Fields(effective.Scopes)
filtered := make([]string, 0, len(parts))
for _, s := range parts {
if strings.Contains(s, "generative-language") {
if hasRestrictedScope(s) {
continue
}
filtered = append(filtered, s)
@@ -214,6 +227,11 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
return effective, nil
}
func hasRestrictedScope(scope string) bool {
return strings.HasPrefix(scope, "https://www.googleapis.com/auth/generative-language") ||
strings.HasPrefix(scope, "https://www.googleapis.com/auth/drive")
}
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
if err != nil {

View File

@@ -0,0 +1,113 @@
package geminicli
import (
"strings"
"testing"
)
func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
tests := []struct {
name string
input OAuthConfig
oauthType string
wantClientID string
wantScopes string
wantErr bool
}{
{
name: "Google One with built-in client (empty config)",
input: OAuthConfig{},
oauthType: "google_one",
wantClientID: GeminiCLIOAuthClientID,
wantScopes: DefaultCodeAssistScopes,
wantErr: false,
},
{
name: "Google One with custom client",
input: OAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantScopes: DefaultGoogleOneScopes,
wantErr: false,
},
{
name: "Google One with built-in client and custom scopes (should filter restricted scopes)",
input: OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
},
oauthType: "google_one",
wantClientID: GeminiCLIOAuthClientID,
wantScopes: "https://www.googleapis.com/auth/cloud-platform",
wantErr: false,
},
{
name: "Google One with built-in client and only restricted scopes (should fallback to default)",
input: OAuthConfig{
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
},
oauthType: "google_one",
wantClientID: GeminiCLIOAuthClientID,
wantScopes: DefaultCodeAssistScopes,
wantErr: false,
},
{
name: "Code Assist with built-in client",
input: OAuthConfig{},
oauthType: "code_assist",
wantClientID: GeminiCLIOAuthClientID,
wantScopes: DefaultCodeAssistScopes,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := EffectiveOAuthConfig(tt.input, tt.oauthType)
if (err != nil) != tt.wantErr {
t.Errorf("EffectiveOAuthConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
return
}
if got.ClientID != tt.wantClientID {
t.Errorf("EffectiveOAuthConfig() ClientID = %v, want %v", got.ClientID, tt.wantClientID)
}
if got.Scopes != tt.wantScopes {
t.Errorf("EffectiveOAuthConfig() Scopes = %v, want %v", got.Scopes, tt.wantScopes)
}
})
}
}
func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
// Test that Google One with built-in client filters out restricted scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile",
}, "google_one")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// Should only contain cloud-platform, userinfo.email, and userinfo.profile
// Should NOT contain generative-language or drive scopes
if strings.Contains(cfg.Scopes, "generative-language") {
t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes)
}
if strings.Contains(cfg.Scopes, "drive") {
t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "userinfo.email") {
t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "userinfo.profile") {
t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes)
}
}

View File

@@ -1,3 +1,4 @@
// Package googleapi provides helpers for Google-style API responses.
package googleapi
import "net/http"

View File

@@ -1,3 +1,4 @@
// Package oauth provides helpers for OAuth flows used by this service.
package oauth
import (

View File

@@ -1,3 +1,4 @@
// Package openai provides helpers and types for OpenAI API integration.
package openai
import _ "embed"

View File

@@ -327,7 +327,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return &claims, nil
}
// ExtractUserInfo extracts user information from ID Token claims
// UserInfo represents user information extracted from ID Token claims.
type UserInfo struct {
Email string
ChatGPTAccountID string

View File

@@ -1,3 +1,4 @@
// Package pagination provides types and helpers for paginated responses.
package pagination
// PaginationParams 分页参数

View File

@@ -1,3 +1,4 @@
// Package response provides standardized HTTP response helpers.
package response
import (

View File

@@ -1,3 +1,4 @@
// Package sysutil provides system-level utilities for process management.
package sysutil
import (

View File

@@ -1,3 +1,4 @@
// Package usagestats provides types for usage statistics and reporting.
package usagestats
import "time"
@@ -10,8 +11,8 @@ type DashboardStats struct {
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
TotalAPIKeys int64 `json:"total_api_keys"`
ActiveAPIKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
@@ -82,10 +83,10 @@ type UserUsageTrendPoint struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
APIKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
@@ -94,8 +95,8 @@ type ApiKeyUsageTrendPoint struct {
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
TotalAPIKeys int64 `json:"total_api_keys"`
ActiveAPIKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
@@ -128,7 +129,7 @@ type UserDashboardStats struct {
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
APIKeyID int64
AccountID int64
GroupID int64
Model string
@@ -157,9 +158,9 @@ type BatchUserUsageStats struct {
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats struct {
APIKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}

View File

@@ -43,6 +43,11 @@ type accountRepository struct {
sql sqlExecutor // 原生 SQL 执行接口
}
type tempUnschedSnapshot struct {
until *time.Time
reason string
}
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
@@ -165,6 +170,11 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs = append(accountIDs, acc.ID)
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -191,6 +201,10 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outByID[entAcc.ID] = out
}
@@ -550,6 +564,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
Where(
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -575,6 +590,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
dbaccount.PlatformEQ(platform),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -607,6 +623,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
dbaccount.PlatformIn(platforms...),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -648,6 +665,31 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
return err
}
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE accounts
SET temp_unschedulable_until = $1,
temp_unschedulable_reason = $2,
updated_at = NOW()
WHERE id = $3
AND deleted_at IS NULL
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
`, until, reason, id)
return err
}
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE accounts
SET temp_unschedulable_until = NULL,
temp_unschedulable_reason = NULL,
updated_at = NOW()
WHERE id = $1
AND deleted_at IS NULL
`, id)
return err
}
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -808,6 +850,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
now := time.Now()
preds = append(preds,
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
)
@@ -869,6 +912,10 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if err != nil {
return nil, err
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -894,12 +941,68 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[acc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outAccounts = append(outAccounts, *out)
}
return outAccounts, nil
}
func tempUnschedulablePredicate() dbpredicate.Account {
return dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
})
}
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
out := make(map[int64]tempUnschedSnapshot)
if len(accountIDs) == 0 {
return out, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
FROM accounts
WHERE id = ANY($1)
`, pq.Array(accountIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var id int64
var until sql.NullTime
var reason sql.NullString
if err := rows.Scan(&id, &until, &reason); err != nil {
return nil, err
}
var untilPtr *time.Time
if until.Valid {
tmp := until.Time
untilPtr = &tmp
}
if reason.Valid {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
} else {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 {

View File

@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
name: "filter_by_type",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
},
accType: service.AccountTypeApiKey,
accType: service.AccountTypeAPIKey,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
},
},
{

View File

@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewApiKeyRepository(entClient)
apiKeyRepo := NewAPIKeyRepository(entClient)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.ApiKey{
key := &service.APIKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",

View File

@@ -24,7 +24,7 @@ type apiKeyCache struct {
rdb *redis.Client
}
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
return &apiKeyCache{rdb: rdb}
}

View File

@@ -16,17 +16,17 @@ type apiKeyRepository struct {
client *dbent.Client
}
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
return &apiKeyRepository{client: client}
}
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
created, err := r.client.ApiKey.Create().
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
WithUser().
@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 不加载完整的 ApiKey 实体及其关联数据User、Group 等)
// - 不加载完整的 API Key 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.activeQuery().
@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrApiKeyNotFound
return 0, service.ErrAPIKeyNotFound
}
return 0, err
}
return m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
WithUser().
@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at避免二次查询带来的并发可见性问题。
now := time.Now()
builder := r.client.ApiKey.Update().
builder := r.client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
}
if affected == 0 {
// 更新影响行数为 0说明记录不存在或已被软删除。
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.ApiKey.Update().
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
return err
}
if affected == 0 {
exists, err := r.client.ApiKey.Query().
exists, err := r.client.APIKey.Query().
Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx))
if err != nil {
@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
if exists {
return nil
}
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
total, err := q.Count(ctx)
@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
return []int64{}, nil
}
ids, err := r.client.ApiKey.Query().
ids, err := r.client.APIKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx)
if err != nil {
@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
total, err := q.Count(ctx)
@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
q := r.activeQuery()
if userID > 0 {
q = q.Where(apikey.UserIDEQ(userID))
@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.ApiKey.Update().
n, err := r.client.APIKey.Update().
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx)
@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err
}
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil
}
out := &service.ApiKey{
out := &service.APIKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,

View File

@@ -12,30 +12,30 @@ import (
"github.com/stretchr/testify/suite"
)
type ApiKeyRepoSuite struct {
type APIKeyRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *apiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
func (s *APIKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
}
func TestApiKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite))
func TestAPIKeyRepoSuite(t *testing.T) {
suite.Run(t, new(APIKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
func (s *APIKeyRepoSuite) TestCreate() {
user := s.mustCreateUser("create@test.com")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
@@ -51,16 +51,16 @@ func (s *ApiKeyRepoSuite) TestCreate() {
s.Require().Equal("sk-create-test", got.Key)
}
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
func (s *APIKeyRepoSuite) TestGetByKey() {
user := s.mustCreateUser("getbykey@test.com")
group := s.mustCreateGroup("g-key")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
@@ -78,16 +78,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey() {
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
func (s *APIKeyRepoSuite) TestUpdate() {
user := s.mustCreateUser("update@test.com")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
@@ -108,10 +108,10 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
user := s.mustCreateUser("cleargroup@test.com")
group := s.mustCreateGroup("g-clear")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
@@ -131,9 +131,9 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
func (s *APIKeyRepoSuite) TestDelete() {
user := s.mustCreateUser("delete@test.com")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
@@ -150,7 +150,7 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
func (s *APIKeyRepoSuite) TestListByUserID() {
user := s.mustCreateUser("listbyuser@test.com")
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
@@ -161,7 +161,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
s.Require().Equal(int64(2), page.Total)
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
@@ -174,7 +174,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
s.Require().Equal(3, page.Pages)
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
func (s *APIKeyRepoSuite) TestCountByUserID() {
user := s.mustCreateUser("count@test.com")
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
@@ -186,7 +186,7 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
func (s *APIKeyRepoSuite) TestListByGroupID() {
user := s.mustCreateUser("listbygroup@test.com")
group := s.mustCreateGroup("g-list")
@@ -202,7 +202,7 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
s.Require().NotNil(keys[0].User)
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
func (s *APIKeyRepoSuite) TestCountByGroupID() {
user := s.mustCreateUser("countgroup@test.com")
group := s.mustCreateGroup("g-count")
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
@@ -214,7 +214,7 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
func (s *APIKeyRepoSuite) TestExistsByKey() {
user := s.mustCreateUser("exists@test.com")
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
@@ -227,41 +227,41 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
s.Require().False(notExists)
}
// --- SearchApiKeys ---
// --- SearchAPIKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
user := s.mustCreateUser("search@test.com")
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
user := s.mustCreateUser("searchnokw@test.com")
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
user := s.mustCreateUser("searchnouid@test.com")
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
user := s.mustCreateUser("cleargrp@test.com")
group := s.mustCreateGroup("g-clear-bulk")
@@ -284,7 +284,7 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := s.mustCreateUser("k@example.com")
group := s.mustCreateGroup("g-k")
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
@@ -320,8 +320,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys")
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
@@ -346,7 +346,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper()
u, err := s.client.User.Create().
@@ -359,7 +359,7 @@ func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
return userEntityToService(u)
}
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
@@ -370,10 +370,10 @@ func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
return groupEntityToService(g)
}
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.APIKey {
s.T().Helper()
k := &service.ApiKey{
k := &service.APIKey{
UserID: userID,
Key: key,
Name: name,

View File

@@ -5,28 +5,20 @@ import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeOAuthServiceSuite struct {
suite.Suite
srv *httptest.Server
client *claudeOAuthService
}
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct {
path string
@@ -37,6 +29,12 @@ type requestCapture struct {
contentType string
}
func newTestReqClient(rt http.RoundTripper) *req.Client {
c := req.C()
c.GetClient().Transport = rt
return c
}
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
tests := []struct {
name string
@@ -83,17 +81,17 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.cookies = r.Cookies()
tt.handler(w, r)
}))
defer s.srv.Close()
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
@@ -158,20 +156,20 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.method = r.Method
captured.cookies = r.Cookies()
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
@@ -266,19 +264,19 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
@@ -362,19 +360,19 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.RefreshToken(context.Background(), "rt", "")

View File

@@ -33,7 +33,7 @@ type usageRequestCapture struct {
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
var captured usageRequestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.authorization = r.Header.Get("Authorization")
captured.anthropicBeta = r.Header.Get("anthropic-beta")
@@ -59,7 +59,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "nope")
}))
@@ -73,7 +73,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
@@ -86,7 +86,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond - simulate slow server
<-r.Context().Done()
}))

View File

@@ -309,7 +309,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}

View File

@@ -1,4 +1,4 @@
// Package infrastructure 提供应用程序的基础设施层组件。
// Package repository 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package repository

View File

@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
return a
}
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
t.Helper()
ctx := context.Background()
@@ -257,7 +257,7 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *se
k.Name = "default"
}
create := client.ApiKey.Create().
create := client.APIKey.Create().
SetUserID(k.UserID).
SetKey(k.Key).
SetName(k.Name).

View File

@@ -30,6 +30,7 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client
// - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,

View File

@@ -49,7 +49,7 @@ func (s *GitHubReleaseServiceSuite) TearDownTest() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "100")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
@@ -68,7 +68,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
@@ -95,7 +95,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
@@ -123,7 +123,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
@@ -140,7 +140,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("sum"))
}))
@@ -155,7 +155,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
@@ -168,7 +168,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
@@ -195,7 +195,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("content"))
}))
@@ -233,7 +233,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
]
}`
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
@@ -258,7 +258,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
@@ -274,7 +274,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not valid json"))
}))
@@ -290,7 +290,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
@@ -308,7 +308,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))

View File

@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// 2. Clear group_id for api keys bound to this group.
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.ApiKey.Update().
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.APIKey.Update().
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx); err != nil {

View File

@@ -3,7 +3,6 @@ package repository
import (
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
@@ -93,7 +92,7 @@ func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
// 验证空代理 URL 时请求直接发送到目标服务器
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
// 创建模拟上游服务器
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct")
}))
s.T().Cleanup(upstream.Close)
@@ -115,7 +114,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// 用于接收代理请求的通道
seen := make(chan string, 1)
// 创建模拟代理服务器
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI // 记录请求 URI
_, _ = io.WriteString(w, "proxied")
}))
@@ -145,7 +144,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
// 验证空字符串代理等同于直连
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty")
}))
s.T().Cleanup(upstream.Close)

View File

@@ -0,0 +1,63 @@
package repository
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
// It captures the request body (if any) and then rewinds it before invoking the handler.
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
var body []byte
if r.Body != nil {
body, _ = io.ReadAll(r.Body)
_ = r.Body.Close()
r.Body = io.NopCloser(bytes.NewReader(body))
}
if capture != nil {
capture(r, body)
}
rec := httptest.NewRecorder()
handler(rec, r)
return rec.Result(), nil
})
}
var (
canListenOnce sync.Once
canListen bool
canListenErr error
)
func localListenerAvailable() bool {
canListenOnce.Do(func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
canListenErr = err
canListen = false
return
}
_ = ln.Close()
canListen = true
})
return canListen
}
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
tb.Helper()
if !localListenerAvailable() {
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
}
return httptest.NewServer(handler)
}

View File

@@ -34,7 +34,7 @@ func (s *OpenAIOAuthServiceSuite) TearDownTest() {
}
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.srv = newLocalTestServer(s.T(), handler)
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
}

View File

@@ -32,7 +32,7 @@ func (s *PricingServiceSuite) TearDownTest() {
}
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.srv = newLocalTestServer(s.T(), handler)
}
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {

View File

@@ -31,7 +31,7 @@ func (s *ProxyProbeServiceSuite) TearDownTest() {
}
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler)
s.proxySrv = newLocalTestServer(s.T(), handler)
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {

View File

@@ -41,8 +41,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
Name: "soft-delete",
@@ -53,13 +53,13 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
_, err := repo.GetByID(ctx, key.ID)
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.ApiKey.Query().
got, err := client.APIKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
@@ -73,8 +73,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
Name: "soft-delete2",
@@ -93,8 +93,8 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
Name: "soft-delete3",
@@ -105,10 +105,10 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "hard delete")
_, err = client.ApiKey.Query().
_, err = client.APIKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")

View File

@@ -0,0 +1,91 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const tempUnschedPrefix = "temp_unsched:account:"
var tempUnschedSetScript = redis.NewScript(`
local key = KEYS[1]
local new_until = tonumber(ARGV[1])
local new_value = ARGV[2]
local new_ttl = tonumber(ARGV[3])
local existing = redis.call('GET', key)
if existing then
local ok, existing_data = pcall(cjson.decode, existing)
if ok and existing_data and existing_data.until_unix then
local existing_until = tonumber(existing_data.until_unix)
if existing_until and new_until <= existing_until then
return 0
end
end
end
redis.call('SET', key, new_value, 'EX', new_ttl)
return 1
`)
type tempUnschedCache struct {
rdb *redis.Client
}
func NewTempUnschedCache(rdb *redis.Client) service.TempUnschedCache {
return &tempUnschedCache{rdb: rdb}
}
// SetTempUnsched 设置临时不可调度状态(只延长不缩短)
func (c *tempUnschedCache) SetTempUnsched(ctx context.Context, accountID int64, state *service.TempUnschedState) error {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
stateJSON, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("marshal state: %w", err)
}
ttl := time.Until(time.Unix(state.UntilUnix, 0))
if ttl <= 0 {
return nil // 已过期,不设置
}
ttlSeconds := int(ttl.Seconds())
if ttlSeconds < 1 {
ttlSeconds = 1
}
_, err = tempUnschedSetScript.Run(ctx, c.rdb, []string{key}, state.UntilUnix, string(stateJSON), ttlSeconds).Result()
return err
}
// GetTempUnsched 获取临时不可调度状态
func (c *tempUnschedCache) GetTempUnsched(ctx context.Context, accountID int64) (*service.TempUnschedState, error) {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, err
}
var state service.TempUnschedState
if err := json.Unmarshal([]byte(val), &state); err != nil {
return nil, fmt.Errorf("unmarshal state: %w", err)
}
return &state, nil
}
// DeleteTempUnsched 删除临时不可调度状态
func (c *tempUnschedCache) DeleteTempUnsched(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -3,9 +3,9 @@ package repository
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
@@ -18,7 +18,6 @@ import (
type TurnstileServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
verifier *turnstileVerifier
received chan url.Values
}
@@ -31,20 +30,15 @@ func (s *TurnstileServiceSuite) SetupTest() {
s.verifier = verifier
}
func (s *TurnstileServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) {
s.verifier.verifyURL = "http://in-process/turnstile"
s.verifier.httpClient = &http.Client{
Transport: newInProcessTransport(handler, nil),
}
}
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.verifier.verifyURL = s.srv.URL
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture form data in main goroutine context later
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
@@ -72,7 +66,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
var contentType string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contentType = r.Header.Get("Content-Type")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
@@ -84,7 +78,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
}
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
@@ -105,15 +99,19 @@ func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
}
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
s.verifier.verifyURL = "http://in-process/turnstile"
s.verifier.httpClient = &http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("dial failed")
}),
}
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error when server is closed")
}
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-valid-json")
}))
@@ -123,7 +121,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
}
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
Success: false,

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"strings"
@@ -60,9 +61,16 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
return requestCount / 5, tokenCount / 5, nil
}
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
if log == nil {
return nil
return false, nil
}
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL保证与其他更新同事务。
// 无事务时回退到默认的 *sql.DB 执行器。
sqlq := r.sql
if tx := dbent.TxFromContext(ctx); tx != nil {
sqlq = tx.Client()
}
createdAt := log.CreatedAt
@@ -70,6 +78,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
createdAt = time.Now()
}
requestID := strings.TrimSpace(log.RequestID)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier
query := `
@@ -107,6 +118,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
`
@@ -115,11 +127,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
var requestIDArg any
if requestID != "" {
requestIDArg = requestID
}
args := []any{
log.UserID,
log.ApiKeyID,
log.APIKeyID,
log.AccountID,
log.RequestID,
requestIDArg,
log.Model,
groupID,
subscriptionID,
@@ -142,11 +159,20 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
firstToken,
createdAt,
}
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
return err
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
return false, err
}
log.RateMultiplier = rateMultiplier
return false, nil
} else {
return false, err
}
}
log.RateMultiplier = rateMultiplier
return nil
return true, nil
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
@@ -183,7 +209,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
}
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
}
@@ -270,8 +296,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
&stats.TotalAPIKeys,
&stats.ActiveAPIKeys,
); err != nil {
return nil, err
}
@@ -418,8 +444,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
return &stats, nil
}
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
@@ -623,7 +649,7 @@ func resolveUsageStatsTimezone() string {
return "UTC"
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
@@ -709,11 +735,11 @@ type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -755,10 +781,10 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
}
}()
results = make([]ApiKeyUsageTrendPoint, 0)
results = make([]APIKeyUsageTrendPoint, 0)
for rows.Next() {
var row ApiKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
var row APIKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err
}
results = append(results, row)
@@ -844,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalApiKeys,
&stats.TotalAPIKeys,
); err != nil {
return nil, err
}
@@ -853,7 +879,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveApiKeys,
&stats.ActiveAPIKeys,
); err != nil {
return nil, err
}
@@ -1023,9 +1049,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.ApiKeyID > 0 {
if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.ApiKeyID)
args = append(args, filters.APIKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
@@ -1145,18 +1171,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
return result, nil
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
result := make(map[int64]*BatchApiKeyUsageStats)
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
for _, id := range apiKeyIDs {
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
query := `
@@ -1582,7 +1608,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if err != nil {
return err
}
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
if err != nil {
return err
}
@@ -1603,8 +1629,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if user, ok := users[logs[i].UserID]; ok {
logs[i].User = user
}
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
logs[i].ApiKey = key
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
logs[i].APIKey = key
}
if acc, ok := accounts[logs[i].AccountID]; ok {
logs[i].Account = acc
@@ -1642,7 +1668,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
for i := range logs {
userIDs[logs[i].UserID] = struct{}{}
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
accountIDs[logs[i].AccountID] = struct{}{}
if logs[i].GroupID != nil {
groupIDs[*logs[i].GroupID] = struct{}{}
@@ -1676,12 +1702,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
return out, nil
}
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
out := make(map[int64]*service.ApiKey)
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
out := make(map[int64]*service.APIKey)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
@@ -1800,7 +1826,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
log := &service.UsageLog{
ID: id,
UserID: userID,
ApiKeyID: apiKeyID,
APIKeyID: apiKeyID,
AccountID: accountID,
Model: model,
InputTokens: inputTokens,

View File

@@ -7,6 +7,8 @@ import (
"testing"
"time"
"github.com/google/uuid"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -35,11 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(), // Generate unique RequestID for each log
Model: "claude-3",
InputTokens: inputTokens,
OutputTokens: outputTokens,
@@ -47,7 +50,8 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
ActualCost: cost,
CreatedAt: createdAt,
}
s.Require().NoError(s.repo.Create(s.ctx, log))
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
return log
}
@@ -55,12 +59,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
@@ -69,14 +73,14 @@ func (s *UsageLogRepoSuite) TestCreate() {
ActualCost: 0.4,
}
err := s.repo.Create(s.ctx, log)
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err, "Create")
s.Require().NotZero(log.ID)
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -96,7 +100,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -112,7 +116,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -124,18 +128,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
s.Require().Equal(int64(2), page.Total)
}
// --- ListByApiKey ---
// --- ListByAPIKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
func (s *UsageLogRepoSuite) TestListByAPIKey() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByApiKey")
logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByAPIKey")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
@@ -144,7 +148,7 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -159,7 +163,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -179,7 +183,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -211,8 +215,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
@@ -223,7 +227,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
d1, d2, d3 := 100, 200, 300
logToday := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
GroupID: &group.ID,
@@ -236,11 +240,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
DurationMs: &d1,
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
}
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
_, err = s.repo.Create(s.ctx, logToday)
s.Require().NoError(err, "Create logToday")
logOld := &service.UsageLog{
UserID: userOld.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 5,
@@ -250,11 +255,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
DurationMs: &d2,
CreatedAt: todayStart.Add(-1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
_, err = s.repo.Create(s.ctx, logOld)
s.Require().NoError(err, "Create logOld")
logPerf := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 1,
@@ -264,7 +270,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
DurationMs: &d3,
CreatedAt: now.Add(-30 * time.Second),
}
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf")
_, err = s.repo.Create(s.ctx, logPerf)
s.Require().NoError(err, "Create logPerf")
stats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats")
@@ -272,8 +279,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch")
s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch")
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
@@ -300,14 +307,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalApiKeys)
s.Require().Equal(int64(1), stats.TotalAPIKeys)
s.Require().Equal(int64(1), stats.TotalRequests)
}
@@ -315,7 +322,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -331,8 +338,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
@@ -351,24 +358,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
s.Require().Empty(stats)
}
// --- GetBatchApiKeyUsageStats ---
// --- GetBatchAPIKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
@@ -377,7 +384,7 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -402,7 +409,7 @@ func maxTime(a, b time.Time) time.Time {
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -417,11 +424,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
s.Require().Len(logs, 2)
}
// --- ListByApiKeyAndTimeRange ---
// --- ListByAPIKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -431,8 +438,8 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByAPIKeyAndTimeRange")
s.Require().Len(logs, 2)
}
@@ -440,7 +447,7 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -459,7 +466,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -467,7 +474,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 10,
@@ -476,11 +483,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
_, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 15,
@@ -489,11 +497,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
ActualCost: 0.6,
CreatedAt: base.Add(30 * time.Minute),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
_, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
log3 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 20,
@@ -502,7 +511,8 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
ActualCost: 0.7,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log3))
_, err = s.repo.Create(s.ctx, log3)
s.Require().NoError(err)
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
@@ -515,7 +525,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
now := time.Now()
@@ -535,7 +545,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -552,7 +562,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -571,7 +581,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -579,7 +589,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -588,11 +598,12 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
_, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -601,7 +612,8 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
_, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
@@ -618,7 +630,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -646,7 +658,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -665,14 +677,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -681,11 +693,12 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
_, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -694,7 +707,8 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
_, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
@@ -719,7 +733,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
@@ -727,7 +741,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
// Create logs on different days
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -736,11 +750,12 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
ActualCost: 0.4,
CreatedAt: base.Add(12 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
_, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -749,7 +764,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
ActualCost: 0.15,
CreatedAt: base.Add(36 * time.Hour), // next day
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
_, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
startTime := base
endTime := base.Add(72 * time.Hour)
@@ -782,8 +798,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -799,12 +815,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
s.Require().GreaterOrEqual(len(trend), 2)
}
// --- GetApiKeyUsageTrend ---
// --- GetAPIKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -815,14 +831,14 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend")
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -832,8 +848,8 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
s.Require().Len(trend, 2)
}
@@ -841,12 +857,12 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1)
@@ -855,7 +871,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -874,7 +890,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -885,7 +901,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
StartTime: &startTime,
EndTime: &endTime,
}

View File

@@ -4,12 +4,13 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -17,14 +18,15 @@ import (
type userRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB)
}
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
return &userRepository{client: client}
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
return &userRepository{client: client, sql: sqlq}
}
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
var attrErr error
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
if attrErr != nil {
return nil, nil, attrErr
}
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
return nil
return nil, nil
}
// For each attribute filter, get the set of matching user IDs
// Then intersect all sets to get users matching ALL filters
var resultSet map[int64]struct{}
first := true
if r.sql == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
clauses := make([]string, 0, len(attrs))
args := make([]any, 0, len(attrs)*2+1)
argIndex := 1
for attrID, value := range attrs {
// Query user_attribute_values for this attribute
values, err := r.client.UserAttributeValue.Query().
Where(
userattributevalue.AttributeIDEQ(attrID),
userattributevalue.ValueContainsFold(value),
).
All(ctx)
if err != nil {
continue
}
currentSet := make(map[int64]struct{}, len(values))
for _, v := range values {
currentSet[v.UserID] = struct{}{}
}
if first {
resultSet = currentSet
first = false
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
}
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
args = append(args, attrID, "%"+value+"%")
argIndex += 2
}
result := make([]int64, 0, len(resultSet))
for userID := range resultSet {
query := fmt.Sprintf(
`SELECT user_id
FROM user_attribute_values
WHERE %s
GROUP BY user_id
HAVING COUNT(DISTINCT attribute_id) = $%d`,
strings.Join(clauses, " OR "),
argIndex,
)
args = append(args, len(attrs))
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make([]int64, 0)
for rows.Next() {
var userID int64
if scanErr := rows.Scan(&userID); scanErr != nil {
return nil, scanErr
}
result = append(result, userID)
}
return result
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {

View File

@@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewApiKeyRepository,
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
@@ -42,7 +42,8 @@ var ProviderSet = wire.NewSet(
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewApiKeyCache,
NewAPIKeyCache,
NewTempUnschedCache,
ProvideConcurrencyCache,
NewEmailCache,
NewIdentityCache,

View File

@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.apiKeyRepo.MustSeed(&service.ApiKey{
deps.apiKeyRepo.MustSeed(&service.APIKey{
ID: 100,
UserID: 1,
Key: "sk_custom_1234567890",
@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
APIKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 10,
@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 2,
UserID: 1,
ApiKeyID: 100,
APIKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 5,
@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
APIKeyID: 100,
AccountID: 200,
RequestID: "req_123",
Model: "claude-3",
@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySmtpHost: "smtp.example.com",
service.SettingKeySmtpPort: "587",
service.SettingKeySmtpUsername: "user",
service.SettingKeySmtpPassword: "secret",
service.SettingKeySmtpFrom: "no-reply@example.com",
service.SettingKeySmtpFromName: "Sub2API",
service.SettingKeySmtpUseTLS: "true",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
service.SettingKeySMTPUsername: "user",
service.SettingKeySMTPPassword: "secret",
service.SettingKeySMTPFrom: "no-reply@example.com",
service.SettingKeySMTPFromName: "Sub2API",
service.SettingKeySMTPUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key",
@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyApiBaseUrl: "https://api.example.com",
service.SettingKeyAPIBaseURL: "https://api.example.com",
service.SettingKeyContactInfo: "support",
service.SettingKeyDocUrl: "https://docs.example.com",
service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
@@ -308,7 +308,12 @@ func TestAPIContracts(t *testing.T) {
"contact_info": "support",
"doc_url": "https://docs.example.com",
"default_concurrency": 5,
"default_balance": 1.25
"default_balance": 1.25,
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro",
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o"
}
}`,
},
@@ -366,16 +371,16 @@ func newContractDeps(t *testing.T) *contractDeps {
cfg := &config.Config{
Default: config.DefaultConfig{
ApiKeyPrefix: "sk-",
APIKeyPrefix: "sk-",
},
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo)
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo)
usageService := service.NewUsageService(usageRepo, userRepo, nil)
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
@@ -664,20 +669,20 @@ type stubApiKeyRepo struct {
now time.Time
nextID int64
byID map[int64]*service.ApiKey
byKey map[string]*service.ApiKey
byID map[int64]*service.APIKey
byKey map[string]*service.APIKey
}
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
return &stubApiKeyRepo{
now: now,
nextID: 100,
byID: make(map[int64]*service.ApiKey),
byKey: make(map[string]*service.ApiKey),
byID: make(map[int64]*service.APIKey),
byKey: make(map[string]*service.APIKey),
}
}
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
func (r *stubApiKeyRepo) MustSeed(key *service.APIKey) {
if key == nil {
return
}
@@ -686,7 +691,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
r.byKey[clone.Key] = &clone
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
if key == nil {
return errors.New("nil key")
}
@@ -706,10 +711,10 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error
return nil
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
key, ok := r.byID[id]
if !ok {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
clone := *key
return &clone, nil
@@ -718,26 +723,26 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
key, ok := r.byID[id]
if !ok {
return 0, service.ErrApiKeyNotFound
return 0, service.ErrAPIKeyNotFound
}
return key.UserID, nil
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
found, ok := r.byKey[key]
if !ok {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
clone := *found
return &clone, nil
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
if key == nil {
return errors.New("nil key")
}
if _, ok := r.byID[key.ID]; !ok {
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now
@@ -751,14 +756,14 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id]
if !ok {
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
delete(r.byID, id)
delete(r.byKey, key.Key)
return nil
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID))
for id := range r.byID {
if r.byID[id].UserID == userID {
@@ -776,7 +781,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
end = len(ids)
}
out := make([]service.ApiKey, 0, end-start)
out := make([]service.APIKey, 0, end-start)
for _, id := range ids[start:end] {
clone := *r.byID[id]
out = append(out, clone)
@@ -830,11 +835,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
return ok, nil
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented")
}
@@ -858,8 +863,8 @@ func (r *stubUsageLogRepo) SetUserLogs(userID int64, logs []service.UsageLog) {
r.userLogs[userID] = logs
}
func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) error {
return errors.New("not implemented")
func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
@@ -877,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
return out, paginationResult(total, params), nil
}
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -890,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
}
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -922,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
@@ -975,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
}, nil
}
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
@@ -995,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented")
}
@@ -1017,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
// Apply filters
var filtered []service.UsageLog
for _, log := range logs {
// Apply ApiKeyID filter
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID {
// Apply APIKeyID filter
if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID {
continue
}
// Apply Model filter
@@ -1151,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
// Ensure compile-time interface compliance.
var (
_ service.UserRepository = (*stubUserRepo)(nil)
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil)
_ service.ApiKeyCache = (*stubApiKeyCache)(nil)
_ service.APIKeyRepository = (*stubApiKeyRepo)(nil)
_ service.APIKeyCache = (*stubApiKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil)

View File

@@ -1,3 +1,4 @@
// Package server provides HTTP server initialization and configuration.
package server
import (
@@ -25,8 +26,8 @@ func ProvideRouter(
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
) *gin.Engine {
if cfg.Server.Mode == "release" {

View File

@@ -1,3 +1,4 @@
// Package middleware provides HTTP middleware for authentication, authorization, and request processing.
package middleware
import (
@@ -32,7 +33,7 @@ func adminAuth(
// 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userService) {
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return
}
c.Next()
@@ -57,14 +58,14 @@ func adminAuth(
}
}
// validateAdminApiKey 验证管理员 API Key
func validateAdminApiKey(
// validateAdminAPIKey 验证管理员 API Key
func validateAdminAPIKey(
c *gin.Context,
key string,
settingService *service.SettingService,
userService *service.UserService,
) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false

View File

@@ -11,13 +11,13 @@ import (
"github.com/gin-gonic/gin"
)
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
// NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) APIKeyAuthMiddleware {
return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
}
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
@@ -60,7 +60,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
if errors.Is(err, service.ErrAPIKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
@@ -88,7 +88,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
@@ -146,7 +146,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
}
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
@@ -157,13 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
}
}
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
// GetAPIKeyFromContext 从上下文中获取API key
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
value, exists := c.Get(string(ContextKeyAPIKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*service.ApiKey)
apiKey, ok := value.(*service.APIKey)
return apiKey, ok
}

View File

@@ -11,16 +11,16 @@ import (
"github.com/gin-gonic/gin"
)
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
// APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
}
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// APIKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
//
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
if errors.Is(err, service.ErrAPIKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key")
return
}
@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
// 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
}
}
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,

View File

@@ -16,53 +16,53 @@ import (
"github.com/stretchr/testify/require"
)
type fakeApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
type fakeAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
}
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if f.getByKey == nil {
return nil, errors.New("unexpected call")
}
return f.getByKey(ctx, key)
}
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
@@ -74,8 +74,8 @@ type googleErrorResponse struct {
} `json:"error"`
}
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
return service.NewApiKeyService(
func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
return service.NewAPIKeyService(
repo,
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
@@ -89,12 +89,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -113,12 +113,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, service.ErrApiKeyNotFound
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, service.ErrAPIKeyNotFound
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -138,12 +138,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("db down")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -163,9 +163,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusDisabled,
@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -196,9 +196,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusActive,
@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)

View File

@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
Balance: 10,
Concurrency: 3,
}
apiKey := &service.ApiKey{
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
@@ -46,9 +46,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
@@ -110,9 +110,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
@@ -120,14 +120,14 @@ func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService
}
type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented")
}
@@ -135,14 +135,14 @@ func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if r.getByKey != nil {
return r.getByKey(ctx, key)
}
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
@@ -150,7 +150,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -166,11 +166,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
return false, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented")
}

View File

@@ -15,8 +15,8 @@ const (
ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色string
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeyAPIKey API密钥上下文键
ContextKeyAPIKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)

View File

@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc
// APIKeyAuthMiddleware API Key 认证中间件类型
type APIKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware,
NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware,
NewAPIKeyAuthMiddleware,
)

View File

@@ -17,8 +17,8 @@ func SetupRouter(
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) *gin.Engine {
@@ -43,8 +43,8 @@ func registerRoutes(
h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {

View File

@@ -1,3 +1,4 @@
// Package routes provides HTTP route registration and handlers.
package routes
import (
@@ -67,10 +68,10 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
}
}
@@ -123,6 +124,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
@@ -203,12 +206,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
}
}
@@ -248,7 +251,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
}
}

View File

@@ -13,8 +13,8 @@ import (
func RegisterGatewayRoutes(
r *gin.Engine,
h *handler.Handlers,
apiKeyAuth middleware.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
@@ -36,7 +36,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
@@ -65,7 +65,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)

View File

@@ -50,7 +50,7 @@ func RegisterUserRoutes(
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
}
// 卡密兑换

View File

@@ -1,3 +1,4 @@
// Package service provides business logic and domain services for the application.
package service
import (
@@ -29,6 +30,9 @@ type Account struct {
RateLimitResetAt *time.Time
OverloadUntil *time.Time
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
@@ -39,6 +43,13 @@ type Account struct {
Groups []*Group
}
type TempUnschedulableRule struct {
ErrorCode int `json:"error_code"`
Keywords []string `json:"keywords"`
DurationMinutes int `json:"duration_minutes"`
Description string `json:"description"`
}
func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
@@ -54,6 +65,9 @@ func (a *Account) IsSchedulable() bool {
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
return true
}
@@ -92,10 +106,7 @@ func (a *Account) GeminiOAuthType() string {
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
return tierID
}
func (a *Account) IsGeminiCodeAssist() bool {
@@ -163,6 +174,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil
}
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false
}
raw, ok := a.Credentials["temp_unschedulable_enabled"]
if !ok || raw == nil {
return false
}
enabled, ok := raw.(bool)
return ok && enabled
}
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["temp_unschedulable_rules"]
if !ok || raw == nil {
return nil
}
arr, ok := raw.([]any)
if !ok {
return nil
}
rules := make([]TempUnschedulableRule, 0, len(arr))
for _, item := range arr {
entry, ok := item.(map[string]any)
if !ok || entry == nil {
continue
}
rule := TempUnschedulableRule{
ErrorCode: parseTempUnschedInt(entry["error_code"]),
Keywords: parseTempUnschedStrings(entry["keywords"]),
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
Description: parseTempUnschedString(entry["description"]),
}
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
continue
}
rules = append(rules, rule)
}
return rules
}
func parseTempUnschedString(value any) string {
s, ok := value.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
func parseTempUnschedStrings(value any) []string {
if value == nil {
return nil
}
var raw []string
switch v := value.(type) {
case []string:
raw = v
case []any:
raw = make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
raw = append(raw, s)
}
}
default:
return nil
}
out := make([]string, 0, len(raw))
for _, item := range raw {
s := strings.TrimSpace(item)
if s != "" {
out = append(out, s)
}
}
return out
}
func parseTempUnschedInt(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return 0
}
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
@@ -206,7 +325,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
if a.Type != AccountTypeAPIKey {
return ""
}
baseURL := a.GetCredential("base_url")
@@ -229,7 +348,7 @@ func (a *Account) GetExtraString(key string) string {
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
@@ -301,14 +420,14 @@ func (a *Account) IsOpenAIOAuth() bool {
}
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
}
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
if a.Type == AccountTypeAPIKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL

View File

@@ -49,6 +49,8 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error

View File

@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim
panic("unexpected SetOverloaded call")
}
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
panic("unexpected SetTempUnschedulable call")
}
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
panic("unexpected ClearTempUnschedulable call")
}
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call")
}

View File

@@ -369,7 +369,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
}
// For API Key accounts with model mapping, map the model
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -393,7 +393,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)

View File

@@ -12,16 +12,18 @@ import (
)
type UsageLogRepository interface {
Create(ctx context.Context, log *UsageLog) error
// Create creates a usage log and returns whether it was actually inserted.
// inserted is false when the insert was skipped due to conflict (idempotent retries).
Create(ctx context.Context, log *UsageLog) (inserted bool, err error)
GetByID(ctx context.Context, id int64) (*UsageLog, error)
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
@@ -32,10 +34,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
@@ -51,7 +53,7 @@ type UsageLogRepository interface {
// Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
@@ -105,6 +107,8 @@ type UsageProgress struct {
ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
UsedRequests int64 `json:"used_requests,omitempty"`
LimitRequests int64 `json:"limit_requests,omitempty"`
}
// AntigravityModelQuota Antigravity 单个模型的配额信息
@@ -115,12 +119,16 @@ type AntigravityModelQuota struct {
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiSharedDaily *UsageProgress `json:"gemini_shared_daily,omitempty"` // Gemini shared pool RPD (Google One / Code Assist)
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
GeminiSharedMinute *UsageProgress `json:"gemini_shared_minute,omitempty"` // Gemini shared pool RPM (Google One / Code Assist)
GeminiProMinute *UsageProgress `json:"gemini_pro_minute,omitempty"` // Gemini Pro RPM
GeminiFlashMinute *UsageProgress `json:"gemini_flash_minute,omitempty"` // Gemini Flash RPM
// Antigravity 多模型配额
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
@@ -256,17 +264,44 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
return usage, nil
}
start := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
dayTotals := geminiAggregateUsage(stats)
dailyResetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
// Daily window (RPD)
if quota.SharedRPD > 0 {
totalReq := dayTotals.ProRequests + dayTotals.FlashRequests
totalTokens := dayTotals.ProTokens + dayTotals.FlashTokens
totalCost := dayTotals.ProCost + dayTotals.FlashCost
usage.GeminiSharedDaily = buildGeminiUsageProgress(totalReq, quota.SharedRPD, dailyResetAt, totalTokens, totalCost, now)
} else {
usage.GeminiProDaily = buildGeminiUsageProgress(dayTotals.ProRequests, quota.ProRPD, dailyResetAt, dayTotals.ProTokens, dayTotals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(dayTotals.FlashRequests, quota.FlashRPD, dailyResetAt, dayTotals.FlashTokens, dayTotals.FlashCost, now)
}
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
minuteTotals := geminiAggregateUsage(minuteStats)
if quota.SharedRPM > 0 {
totalReq := minuteTotals.ProRequests + minuteTotals.FlashRequests
totalTokens := minuteTotals.ProTokens + minuteTotals.FlashTokens
totalCost := minuteTotals.ProCost + minuteTotals.FlashCost
usage.GeminiSharedMinute = buildGeminiUsageProgress(totalReq, quota.SharedRPM, minuteResetAt, totalTokens, totalCost, now)
} else {
usage.GeminiProMinute = buildGeminiUsageProgress(minuteTotals.ProRequests, quota.ProRPM, minuteResetAt, minuteTotals.ProTokens, minuteTotals.ProCost, now)
usage.GeminiFlashMinute = buildGeminiUsageProgress(minuteTotals.FlashRequests, quota.FlashRPM, minuteResetAt, minuteTotals.FlashTokens, minuteTotals.FlashCost, now)
}
return usage, nil
}
@@ -506,6 +541,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
}
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
// limit <= 0 means "no local quota window" (unknown or unlimited).
if limit <= 0 {
return nil
}
@@ -519,6 +555,8 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
Utilization: utilization,
ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds,
UsedRequests: used,
LimitRequests: limit,
WindowStats: &WindowStats{
Requests: used,
Tokens: tokens,

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -19,7 +20,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
@@ -30,7 +31,7 @@ type AdminService interface {
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
@@ -65,7 +66,7 @@ type AdminService interface {
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
}
// Input types for admin operations
// CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct {
Email string
Password string
@@ -122,18 +123,22 @@ type CreateAccountInput struct {
Concurrency int
Priority int
GroupIDs []int64
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
}
type UpdateAccountInput struct {
Name string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
Name string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
}
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
@@ -147,6 +152,9 @@ type BulkUpdateAccountsInput struct {
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
}
// BulkUpdateAccountResult captures the result for a single account update.
@@ -220,7 +228,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo ApiKeyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
@@ -232,7 +240,7 @@ func NewAdminService(
groupRepo GroupRepository,
accountRepo AccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo ApiKeyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
@@ -430,7 +438,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
@@ -583,7 +591,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
@@ -620,6 +628,29 @@ func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
// 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
break
}
}
}
}
// 检查混合渠道风险(除非用户已确认)
if len(groupIDs) > 0 && !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil {
return nil, err
}
}
account := &Account{
Name: input.Name,
Platform: input.Platform,
@@ -637,22 +668,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
// 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID)
break
}
}
}
}
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
return nil, err
@@ -703,6 +718,13 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return nil, fmt.Errorf("get group: %w", err)
}
}
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil {
return nil, err
}
}
}
if err := s.accountRepo.Update(ctx, account); err != nil {
@@ -731,6 +753,20 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil
}
// Preload account platforms for mixed channel risk checks if group bindings are requested.
platformByID := map[int64]string{}
if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
return nil, err
}
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
}
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials,
@@ -762,6 +798,29 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
platform := platformByID[accountID]
if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
platform = account.Platform
}
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
@@ -1006,3 +1065,77 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
Country: exitInfo.Country,
}, nil
}
// checkMixedChannelRisk 检查分组中是否存在混合渠道Antigravity + Anthropic
// 如果存在混合,返回错误提示用户确认
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
// 判断当前账号的渠道类型(基于 platform 字段,而不是 type 字段)
currentPlatform := getAccountPlatform(currentAccountPlatform)
if currentPlatform == "" {
// 不是 Antigravity 或 Anthropic无需检查
return nil
}
// 检查每个分组中的其他账号
for _, groupID := range groupIDs {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return fmt.Errorf("get accounts in group %d: %w", groupID, err)
}
// 检查是否存在不同渠道的账号
for _, account := range accounts {
if currentAccountID > 0 && account.ID == currentAccountID {
continue // 跳过当前账号
}
otherPlatform := getAccountPlatform(account.Platform)
if otherPlatform == "" {
continue // 不是 Antigravity 或 Anthropic跳过
}
// 检测混合渠道
if currentPlatform != otherPlatform {
group, _ := s.groupRepo.GetByID(ctx, groupID)
groupName := fmt.Sprintf("Group %d", groupID)
if group != nil {
groupName = group.Name
}
return &MixedChannelError{
GroupID: groupID,
GroupName: groupName,
CurrentPlatform: currentPlatform,
OtherPlatform: otherPlatform,
}
}
}
}
return nil
}
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func getAccountPlatform(accountPlatform string) string {
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
case PlatformAntigravity:
return "Antigravity"
case PlatformAnthropic, "claude":
return "Anthropic"
default:
return ""
}
}
// MixedChannelError 混合渠道错误
type MixedChannelError struct {
GroupID int64
GroupName string
CurrentPlatform string
OtherPlatform string
}
func (e *MixedChannelError) Error() string {
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}

View File

@@ -81,6 +81,7 @@ type AntigravityGatewayService struct {
tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
settingService *SettingService
}
func NewAntigravityGatewayService(
@@ -89,12 +90,14 @@ func NewAntigravityGatewayService(
tokenProvider *AntigravityTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
settingService *SettingService,
) *AntigravityGatewayService {
return &AntigravityGatewayService{
accountRepo: accountRepo,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
settingService: settingService,
}
}
@@ -324,6 +327,22 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
return body, nil
}
// isModelNotFoundError 检测是否为模型不存在的 404 错误
func isModelNotFoundError(statusCode int, body []byte) bool {
if statusCode != 404 {
return false
}
bodyStr := strings.ToLower(string(body))
keywords := []string{"model not found", "unknown model", "not found"}
for _, keyword := range keywords {
if strings.Contains(bodyStr, keyword) {
return true
}
}
return true // 404 without specific message also treated as model not found
}
// Forward 转发 Claude 协议请求Claude → Gemini 转换)
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
@@ -417,16 +436,56 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
// 优先检测 thinking block 的 signature 相关错误400并重试一次
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq)
if stripErr == nil && stripped {
log.Printf("Antigravity account %d: detected signature-related 400, retrying once without thinking blocks", account.ID)
retryGeminiBody, txErr := antigravity.TransformClaudeToGemini(&retryClaudeReq, projectID, mappedModel)
if txErr == nil {
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// Retry success: continue normal success flow with the new response.
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
} else {
// Retry still errored: replace error context with retry response.
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
respBody = retryBody
resp = retryResp
}
} else {
log.Printf("Antigravity account %d: signature retry request failed: %v", account.ID, retryErr)
}
}
}
}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
// 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
}
}
requestID := resp.Header.Get("x-request-id")
@@ -461,6 +520,122 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}, nil
}
func isSignatureRelatedError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
// Fallback: best-effort scan of the raw payload.
msg = strings.ToLower(string(respBody))
}
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
}
func extractAntigravityErrorMessage(body []byte) string {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return ""
}
// Google-style: {"error": {"message": "..."}}
if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
return msg
}
}
// Fallback: top-level message
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
return msg
}
return ""
}
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
// It also disables top-level `thinking` to prevent dummy-thought injection during retry.
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
}
changed := false
if req.Thinking != nil {
req.Thinking = nil
changed = true
}
for i := range req.Messages {
raw := req.Messages[i].Content
if len(raw) == 0 {
continue
}
// If content is a string, nothing to strip.
var str string
if json.Unmarshal(raw, &str) == nil {
continue
}
// Otherwise treat as an array of blocks and convert thinking blocks to text.
var blocks []map[string]any
if err := json.Unmarshal(raw, &blocks); err != nil {
continue
}
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
t, _ := block["type"].(string)
switch t {
case "thinking":
// Convert thinking to text, skip if empty
thinkingText, _ := block["thinking"].(string)
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
case "redacted_thinking":
// Remove redacted_thinking (cannot convert encrypted content)
modifiedAny = true
case "":
// Handle untyped block with "thinking" field
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
} else {
filtered = append(filtered, block)
}
default:
filtered = append(filtered, block)
}
}
if !modifiedAny {
continue
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
}
req.Messages[i].Content = newRaw
changed = true
}
return changed, nil
}
// ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
@@ -574,14 +749,40 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
defer func() { _ = resp.Body.Close() }()
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
isModelNotFoundError(resp.StatusCode, respBody) {
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
// 关闭原始响应释放连接respBody 已读取到内存)
_ = resp.Body.Close()
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
if err == nil && fallbackResp.StatusCode < 400 {
resp = fallbackResp
} else if fallbackResp != nil {
_ = fallbackResp.Body.Close()
}
}
}
}
}
// fallback 成功:继续按正常响应处理
if resp.StatusCode < 400 {
goto handleSuccess
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
@@ -589,6 +790,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
// 解包并返回错误
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
@@ -598,6 +803,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
handleSuccess:
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
var usage *ClaudeUsage
var firstTokenMs *int

View File

@@ -2,7 +2,7 @@ package service
import "time"
type ApiKey struct {
type APIKey struct {
ID int64
UserID int64
Key string
@@ -15,6 +15,6 @@ type ApiKey struct {
Group *Group
}
func (k *ApiKey) IsActive() bool {
func (k *APIKey) IsActive() bool {
return k.Status == StatusActive
}

View File

@@ -14,39 +14,39 @@ import (
)
var (
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
apiKeyMaxErrorsPerHour = 20
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
GetByKey(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
@@ -55,40 +55,40 @@ type ApiKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo ApiKeyRepository
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache ApiKeyCache
cache APIKeyCache
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo ApiKeyRepository,
// NewAPIKeyService 创建API Key服务实例
func NewAPIKeyService(
apiKeyRepo APIKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
cache ApiKeyCache,
cache APIKeyCache,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
) *APIKeyService {
return &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -99,7 +99,7 @@ func NewApiKeyService(
}
// GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) {
func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix
prefix := s.cfg.Default.APIKeyPrefix
if prefix == "" {
prefix = "sk-"
}
@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
}
// ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error {
func (s *APIKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrApiKeyTooShort
return ErrAPIKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' {
continue
}
return ErrApiKeyInvalidChars
return ErrAPIKeyInvalidChars
}
return nil
}
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
}
if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited
return ErrAPIKeyRateLimited
}
return nil
}
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
@@ -220,8 +220,8 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
if exists {
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrAPIKeyExists
}
key = *req.CustomKey
@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// 创建API Key记录
apiKey := &ApiKey{
apiKey := &APIKey{
UserID: userID,
Key: key,
Name: req.Name,
@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil
}
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 ApiKey 对象及其关联数据User、Group提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil {
@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}

View File

@@ -1,7 +1,7 @@
//go:build unit
// API Key 服务删除方法的单元测试
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理
package service
@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require"
)
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call")
}
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call")
}
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
panic("unexpected SearchApiKeys call")
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
@@ -142,7 +142,7 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms)
@@ -160,7 +160,7 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err)
@@ -170,17 +170,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrApiKeyNotFound)
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
}
@@ -198,7 +198,7 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err)

View File

@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple {
return nil

View File

@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID

View File

@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil
}
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -28,7 +28,7 @@ const (
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
AccountTypeAPIKey = "apikey" // API Key类型账号
)
// Redeem type constants
@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
SettingKeySMTPPort = "smtp_port" // SMTP端口
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
SettingKeySMTPPassword = "smtp_password" // SMTP密码加密存储
SettingKeySMTPFrom = "smtp_from" // 发件人地址
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
@@ -81,20 +81,27 @@ const (
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
SettingKeyDocURL = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
// Model fallback settings
SettingKeyEnableModelFallback = "enable_model_fallback"
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
)
// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const AdminAPIKeyPrefix = "admin-"

View File

@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5
)
// SmtpConfig SMTP配置
type SmtpConfig struct {
// SMTPConfig SMTP配置
type SMTPConfig struct {
Host string
Port int
Username string
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
}
}
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
// GetSMTPConfig 从数据库获取SMTP配置
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
keys := []string{
SettingKeySmtpHost,
SettingKeySmtpPort,
SettingKeySmtpUsername,
SettingKeySmtpPassword,
SettingKeySmtpFrom,
SettingKeySmtpFromName,
SettingKeySmtpUseTLS,
SettingKeySMTPHost,
SettingKeySMTPPort,
SettingKeySMTPUsername,
SettingKeySMTPPassword,
SettingKeySMTPFrom,
SettingKeySMTPFromName,
SettingKeySMTPUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[SettingKeySmtpHost]
host := settings[SettingKeySMTPHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[SettingKeySmtpUseTLS] == "true"
useTLS := settings[SettingKeySMTPUseTLS] == "true"
return &SmtpConfig{
return &SMTPConfig{
Host: host,
Port: port,
Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySmtpFromName],
Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySMTPFromName],
UseTLS: useTLS,
}, nil
}
// SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSmtpConfig(ctx)
config, err := s.GetSMTPConfig(ctx)
if err != nil {
return err
}
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
}
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code)
}
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS {

View File

@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil
}
@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},

View File

@@ -1,6 +1,7 @@
package service
import (
"bytes"
"encoding/json"
"fmt"
)
@@ -70,3 +71,224 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
return parsed, nil
}
// FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures
//
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
func FilterThinkingBlocks(body []byte) []byte {
return filterThinkingBlocksInternal(body, false)
}
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
// This is used when upstream returns signature-related 400 errors.
//
// Key insight:
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
// - Only HISTORICAL assistant messages have thinking blocks with signatures
// - These signatures may be invalid when switching accounts/platforms
// - New responses will generate fresh thinking blocks without signature issues
//
// Strategy:
// - Keep thinking.type = "enabled" (preserve user intent)
// - Remove thinking/redacted_thinking blocks from historical assistant messages
// - Ensure no message has empty content after filtering
func FilterThinkingBlocksForRetry(body []byte) []byte {
// Fast path: check for presence of thinking-related keys in messages
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) {
return body
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
// The issue is with historical message signatures, not the thinking mode itself
messages, ok := req["messages"].([]any)
if !ok {
return body
}
modified := false
newMessages := make([]any, 0, len(messages))
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
newMessages = append(newMessages, msg)
continue
}
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
// String content or other format - keep as is
newMessages = append(newMessages, msg)
continue
}
newContent := make([]any, 0, len(content))
modifiedThisMsg := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
newContent = append(newContent, block)
continue
}
blockType, _ := blockMap["type"].(string)
// Remove thinking/redacted_thinking blocks from historical messages
// These have signatures that may be invalid across different accounts
if blockType == "thinking" || blockType == "redacted_thinking" {
modifiedThisMsg = true
continue
}
newContent = append(newContent, block)
}
if modifiedThisMsg {
modified = true
// Handle empty content after filtering
if len(newContent) == 0 {
// For assistant messages, skip entirely (remove from conversation)
// For user messages, add placeholder to avoid empty content error
if role == "user" {
newContent = append(newContent, map[string]any{
"type": "text",
"text": "(content removed)",
})
msgMap["content"] = newContent
newMessages = append(newMessages, msgMap)
}
// Skip assistant messages with empty content (don't append)
continue
}
msgMap["content"] = newContent
}
newMessages = append(newMessages, msgMap)
}
if modified {
req["messages"] = newMessages
}
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}
// filterThinkingBlocksInternal removes invalid thinking blocks from request
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
// Fast path: if body doesn't contain "thinking", skip parsing
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"thinking":`)) &&
!bytes.Contains(body, []byte(`"thinking" :`)) {
return body
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
// Check if thinking is enabled
thinkingEnabled := false
if thinking, ok := req["thinking"].(map[string]any); ok {
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
thinkingEnabled = true
}
}
messages, ok := req["messages"].([]any)
if !ok {
return body
}
filtered := false
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
newContent := make([]any, 0, len(content))
filteredThisMessage := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
newContent = append(newContent, block)
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" || blockType == "redacted_thinking" {
// When thinking is enabled and this is an assistant message,
// only keep thinking blocks with valid signatures
if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string)
if signature != "" && signature != "skip_thought_signature_validator" {
newContent = append(newContent, block)
continue
}
}
filtered = true
filteredThisMessage = true
continue
}
// Handle blocks without type discriminator but with "thinking" key
if blockType == "" {
if _, hasThinking := blockMap["thinking"]; hasThinking {
filtered = true
filteredThisMessage = true
continue
}
}
newContent = append(newContent, block)
}
if filteredThisMessage {
msgMap["content"] = newContent
}
}
if !filtered {
return body
}
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}

View File

@@ -1,6 +1,7 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
@@ -38,3 +39,115 @@ func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
_, err := ParseGatewayRequest(body)
require.Error(t, err)
}
func TestFilterThinkingBlocks(t *testing.T) {
containsThinkingBlock := func(body []byte) bool {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return false
}
messages, ok := req["messages"].([]any)
if !ok {
return false
}
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" {
return true
}
if blockType == "" {
if _, hasThinking := blockMap["thinking"]; hasThinking {
return true
}
}
}
}
return false
}
tests := []struct {
name string
input string
shouldFilter bool
expectError bool
}{
{
name: "filters thinking blocks",
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
shouldFilter: true,
},
{
name: "handles no thinking blocks",
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
shouldFilter: false,
},
{
name: "handles invalid JSON gracefully",
input: `{invalid json`,
shouldFilter: false,
expectError: true,
},
{
name: "handles multiple messages with thinking blocks",
input: `{"messages":[{"role":"user","content":[{"type":"text","text":"A"}]},{"role":"assistant","content":[{"type":"thinking","thinking":"think"},{"type":"text","text":"B"}]}]}`,
shouldFilter: true,
},
{
name: "filters thinking blocks without type discriminator",
input: `{"messages":[{"role":"assistant","content":[{"thinking":{"text":"internal"}},{"type":"text","text":"B"}]}]}`,
shouldFilter: true,
},
{
name: "does not filter tool_use input fields named thinking",
input: `{"messages":[{"role":"user","content":[{"type":"tool_use","id":"t1","name":"foo","input":{"thinking":"keepme","x":1}},{"type":"text","text":"Hello"}]}]}`,
shouldFilter: false,
},
{
name: "handles empty messages array",
input: `{"messages":[]}`,
shouldFilter: false,
},
{
name: "handles missing messages field",
input: `{"model":"claude-3"}`,
shouldFilter: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FilterThinkingBlocks([]byte(tt.input))
if tt.expectError {
// For invalid JSON, should return original
require.Equal(t, tt.input, string(result))
return
}
if tt.shouldFilter {
require.False(t, containsThinkingBlock(result))
} else {
// Ensure we don't rewrite JSON when no filtering is needed.
require.Equal(t, tt.input, string(result))
}
// Verify valid JSON returned (unless input was invalid)
var parsed map[string]any
err := json.Unmarshal(result, &parsed)
require.NoError(t, err)
})
}
}

View File

@@ -543,7 +543,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
@@ -579,7 +579,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
@@ -710,7 +710,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
@@ -783,7 +783,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
// 4. 建立粘性绑定
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
@@ -799,7 +799,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
@@ -875,7 +875,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
// 4. 建立粘性绑定
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
@@ -907,7 +907,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -1045,7 +1045,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
@@ -1082,8 +1082,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("upstream request failed: %w", err)
}
// 检查是否需要重试
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
// 优先检测thinking block签名错误400并重试一次
if resp.StatusCode == 400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试使用更激进的过滤
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded", account.ID)
} else {
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode)
}
resp = retryResp
break
}
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
} else {
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
}
// 重试失败,恢复原始响应体继续处理
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误恢复响应体
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
// 检查是否需要通用重试排除400因为400已经在上面特殊处理过了
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
@@ -1096,6 +1133,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 不需要重试(成功或不可重试的错误),跳出循环
// DEBUG: 输出响应 headers用于检测 rate limit 信息)
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID)
for k, v := range resp.Header {
log.Printf("[DEBUG] %s: %v", k, v)
}
}
break
}
defer func() { _ = resp.Body.Close() }()
@@ -1119,7 +1163,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
@@ -1179,7 +1223,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
@@ -1243,10 +1287,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
@@ -1313,12 +1357,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
return claude.APIKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
return claude.APIKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -1335,6 +1379,41 @@ func truncateForLog(b []byte, maxBytes int) string {
return s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// Log for debugging
log.Printf("[SignatureCheck] Checking error message: %s", msg)
// 检测signature相关的错误更宽松的匹配
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
if strings.Contains(msg, "signature") {
log.Printf("[SignatureCheck] Detected signature error")
return true
}
// 检测 thinking block 顺序/类型错误
// 例如: "Expected `thinking` or `redacted_thinking`, but found `text`"
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
log.Printf("[SignatureCheck] Detected thinking block type error")
return true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
log.Printf("[SignatureCheck] Detected empty content error")
return true
}
return false
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
@@ -1383,7 +1462,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
@@ -1695,7 +1780,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
@@ -1704,7 +1789,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -1741,7 +1826,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1771,7 +1856,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.SubscriptionID = &subscription.ID
}
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
log.Printf("Create usage log failed: %v", err)
}
@@ -1781,10 +1867,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil
}
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 {
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
@@ -1793,7 +1881,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 {
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
@@ -1826,7 +1914,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1863,17 +1951,35 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
// 检测 thinking block 签名错误400并重试一次过滤 thinking blocks
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocks(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
resp = retryResp
respBody, err = io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
}
}
}
// 处理错误响应
if resp.StatusCode >= 400 {
// 标记账号状态429/529等
@@ -1912,7 +2018,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
@@ -1971,10 +2077,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}

View File

@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return 999
}
switch a.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -539,7 +539,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -614,7 +621,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
mappedModel := originalModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -636,7 +643,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -825,6 +832,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
@@ -842,6 +853,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
if tempMatched {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -1614,6 +1628,15 @@ type UpstreamHTTPResult struct {
}
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ========================================")
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
@@ -1644,6 +1667,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
}
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ====================================================")
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
@@ -1758,7 +1790,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
}
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")
@@ -2177,10 +2209,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
parts := make([]any, 0)
switch content := mm["content"].(type) {
case string:
if strings.TrimSpace(content) != "" {
parts = append(parts, map[string]any{"text": content})
}
// 字符串形式的 content保留所有内容包括空白
parts = append(parts, map[string]any{"text": content})
case []any:
// 如果只有一个 block不过滤空白让上游 API 报错)
singleBlock := len(content) == 1
for _, block := range content {
bm, ok := block.(map[string]any)
if !ok {
@@ -2189,8 +2223,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
bt, _ := bm["type"].(string)
switch bt {
case "text":
if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
if text, ok := bm["text"].(string); ok {
// 单个 block 时保留所有内容(包括空白)
// 多个 blocks 时过滤掉空白
if singleBlock || strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
}
}
case "tool_use":
id, _ := bm["id"].(string)

View File

@@ -121,6 +121,12 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
@@ -275,7 +281,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"regexp"
"strconv"
@@ -18,12 +19,23 @@ import (
)
const (
TierAIPremium = "AI_PREMIUM"
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
TierFree = "FREE"
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
// Canonical tier IDs used by sub2api (2026-aligned).
GeminiTierGoogleOneFree = "google_one_free"
GeminiTierGoogleAIPro = "google_ai_pro"
GeminiTierGoogleAIUltra = "google_ai_ultra"
GeminiTierGCPStandard = "gcp_standard"
GeminiTierGCPEnterprise = "gcp_enterprise"
GeminiTierAIStudioFree = "aistudio_free"
GeminiTierAIStudioPaid = "aistudio_paid"
GeminiTierGoogleOneUnknown = "google_one_unknown"
// Legacy/compat tier IDs that may exist in historical data or upstream responses.
legacyTierAIPremium = "AI_PREMIUM"
legacyTierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
legacyTierGoogleOneBasic = "GOOGLE_ONE_BASIC"
legacyTierFree = "FREE"
legacyTierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
legacyTierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
)
const (
@@ -84,7 +96,7 @@ type GeminiAuthURLResult struct {
State string `json:"state"`
}
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType, tierID string) (*GeminiAuthURLResult, error) {
state, err := geminicli.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
@@ -109,14 +121,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
// - ai_studio: requires a user-provided OAuth client.
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" || oauthType == "google_one" {
if oauthType == "code_assist" {
oauthCfg.ClientID = ""
oauthCfg.ClientSecret = ""
}
@@ -127,6 +139,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
ProxyURL: proxyURL,
RedirectURI: redirectURI,
ProjectID: strings.TrimSpace(projectID),
TierID: canonicalGeminiTierIDForOAuthType(oauthType, tierID),
OAuthType: oauthType,
CreatedAt: time.Now(),
}
@@ -146,9 +159,9 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
}
// Redirect URI strategy:
// - code_assist: use Gemini CLI redirect URI (codeassist.google.com/authcode)
// - ai_studio: use localhost callback for manual copy/paste flow
if oauthType == "code_assist" {
// - built-in Gemini CLI OAuth client: use upstream redirect URI (codeassist.google.com/authcode)
// - custom OAuth client: use localhost callback for manual copy/paste flow
if isBuiltinClient {
redirectURI = geminicli.GeminiCLIRedirectURI
} else {
redirectURI = geminicli.AIStudioOAuthRedirectURI
@@ -174,6 +187,9 @@ type GeminiExchangeCodeInput struct {
Code string
ProxyID *int64
OAuthType string // "code_assist" 或 "ai_studio"
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
// If empty, the service will fall back to the tier stored in the OAuth session (if any).
TierID string
}
type GeminiTokenInfo struct {
@@ -185,7 +201,7 @@ type GeminiTokenInfo struct {
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
TierID string `json:"tier_id,omitempty"` // Canonical tier id (e.g. google_one_free, gcp_standard, aistudio_free)
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
}
@@ -204,6 +220,90 @@ func validateTierID(tierID string) error {
return nil
}
func canonicalGeminiTierID(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
lower := strings.ToLower(raw)
switch lower {
case GeminiTierGoogleOneFree,
GeminiTierGoogleAIPro,
GeminiTierGoogleAIUltra,
GeminiTierGCPStandard,
GeminiTierGCPEnterprise,
GeminiTierAIStudioFree,
GeminiTierAIStudioPaid,
GeminiTierGoogleOneUnknown:
return lower
}
upper := strings.ToUpper(raw)
switch upper {
// Google One legacy tiers
case legacyTierAIPremium:
return GeminiTierGoogleAIPro
case legacyTierGoogleOneUnlimited:
return GeminiTierGoogleAIUltra
case legacyTierFree, legacyTierGoogleOneBasic, legacyTierGoogleOneStandard:
return GeminiTierGoogleOneFree
case legacyTierGoogleOneUnknown:
return GeminiTierGoogleOneUnknown
// Code Assist legacy tiers
case "STANDARD", "PRO", "LEGACY":
return GeminiTierGCPStandard
case "ENTERPRISE", "ULTRA":
return GeminiTierGCPEnterprise
}
// Some Code Assist responses use kebab-case tier identifiers.
switch lower {
case "standard-tier", "pro-tier":
return GeminiTierGCPStandard
case "ultra-tier":
return GeminiTierGCPEnterprise
}
return ""
}
func canonicalGeminiTierIDForOAuthType(oauthType, tierID string) string {
oauthType = strings.ToLower(strings.TrimSpace(oauthType))
canonical := canonicalGeminiTierID(tierID)
if canonical == "" {
return ""
}
switch oauthType {
case "google_one":
switch canonical {
case GeminiTierGoogleOneFree, GeminiTierGoogleAIPro, GeminiTierGoogleAIUltra:
return canonical
default:
return ""
}
case "code_assist":
switch canonical {
case GeminiTierGCPStandard, GeminiTierGCPEnterprise:
return canonical
default:
return ""
}
case "ai_studio":
switch canonical {
case GeminiTierAIStudioFree, GeminiTierAIStudioPaid:
return canonical
default:
return ""
}
default:
// Unknown oauth type: accept canonical tier.
return canonical
}
}
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
// Prioritizes IsDefault tier, falls back to first non-empty tier
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
@@ -229,45 +329,61 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string
// inferGoogleOneTier infers Google One tier from Drive storage limit
func inferGoogleOneTier(storageBytes int64) string {
log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
if storageBytes <= 0 {
return TierGoogleOneUnknown
log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
return GeminiTierGoogleOneUnknown
}
if storageBytes > StorageTierUnlimited {
return TierGoogleOneUnlimited
log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
return GeminiTierGoogleAIUltra
}
if storageBytes >= StorageTierAIPremium {
return TierAIPremium
}
if storageBytes >= StorageTierStandard {
return TierGoogleOneStandard
}
if storageBytes >= StorageTierBasic {
return TierGoogleOneBasic
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
return GeminiTierGoogleAIPro
}
if storageBytes >= StorageTierFree {
return TierFree
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
return GeminiTierGoogleOneFree
}
return TierGoogleOneUnknown
log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
return GeminiTierGoogleOneUnknown
}
// fetchGoogleOneTier fetches Google One tier from Drive API
// FetchGoogleOneTier fetches Google One tier from Drive API.
// Note: LoadCodeAssist API is NOT called for Google One accounts because:
// 1. It's designed for GCP IAM (enterprise), not personal Google accounts
// 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com
// 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
log.Printf("[GeminiOAuth] Calling Drive API for storage quota...")
driveClient := geminicli.NewDriveClient()
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
if err != nil {
// Check if it's a 403 (scope not granted)
if strings.Contains(err.Error(), "status 403") {
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
return TierGoogleOneUnknown, nil, err
log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err)
return GeminiTierGoogleOneUnknown, nil, err
}
// Other errors
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
return TierGoogleOneUnknown, nil, err
log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err)
return GeminiTierGoogleOneUnknown, nil, err
}
log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
tierID := inferGoogleOneTier(storageInfo.Limit)
log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID)
return tierID, storageInfo, nil
}
@@ -326,11 +442,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========")
log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID)
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
log.Printf("[GeminiOAuth] ERROR: Session not found or expired")
return nil, fmt.Errorf("session not found or expired")
}
if strings.TrimSpace(input.State) == "" || input.State != session.State {
log.Printf("[GeminiOAuth] ERROR: Invalid state")
return nil, fmt.Errorf("invalid state")
}
@@ -341,6 +462,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
proxyURL = proxy.URL()
}
}
log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL)
redirectURI := session.RedirectURI
@@ -349,6 +471,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
if oauthType == "" {
oauthType = "code_assist"
}
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID)
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
if oauthType == "ai_studio" {
@@ -374,8 +498,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil {
log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
log.Printf("[GeminiOAuth] Token exchange successful")
log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope)
log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID)
@@ -391,43 +520,91 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
projectID := sessionProjectID
var tierID string
fallbackTierID := canonicalGeminiTierIDForOAuthType(oauthType, input.TierID)
if fallbackTierID == "" {
fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID)
}
log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========")
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
// 对于 code_assist 模式project_id 是必需的,需要调用 Code Assist API
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id配额由 Google 网关自动识别
// 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API
switch oauthType {
case "code_assist":
log.Printf("[GeminiOAuth] Processing code_assist OAuth type")
if projectID == "" {
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
var err error
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
} else {
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
}
} else {
log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
// 用户手动填了 project_id仍需调用 LoadCodeAssist 获取 tierID
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
} else {
tierID = fetchedTierID
log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
}
}
if strings.TrimSpace(projectID) == "" {
log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
// tierID 缺失时使用默认值
// Prefer auto-detected tier; fall back to user-selected tier.
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
if tierID == "" {
tierID = "LEGACY"
if fallbackTierID != "" {
tierID = fallbackTierID
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
} else {
tierID = GeminiTierGCPStandard
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
}
}
log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
case "google_one":
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
// Attempt to fetch Drive storage tier
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
var storageInfo *geminicli.DriveStorageInfo
var err error
tierID, storageInfo, err = s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// Log warning but don't block - use fallback
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
tierID = TierGoogleOneUnknown
log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
tierID = ""
} else {
log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
if storageInfo != nil {
log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
}
}
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
if tierID == "" || tierID == GeminiTierGoogleOneUnknown {
if fallbackTierID != "" {
tierID = fallbackTierID
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
} else {
tierID = GeminiTierGoogleOneFree
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
}
}
fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID)
// Store Drive info in extra field for caching
if storageInfo != nil {
@@ -447,12 +624,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
},
}
log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
return tokenInfo, nil
}
}
// ai_studio 模式不设置 tierID保持为空
return &GeminiTokenInfo{
case "ai_studio":
// No automatic tier detection for AI Studio OAuth; rely on user selection.
if fallbackTierID != "" {
tierID = fallbackTierID
} else {
tierID = GeminiTierAIStudioFree
}
default:
log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
}
log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========")
result := &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
@@ -462,7 +652,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType,
}, nil
}
log.Printf("[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========")
return result, nil
}
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
@@ -558,6 +751,17 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
err = nil
}
}
// Backward compatibility for google_one:
// - New behavior: when a custom OAuth client is configured, google_one will use it.
// - Old behavior: google_one always used the built-in Gemini CLI OAuth client.
// If an existing account was authorized with the built-in client, refreshing with the custom client
// will fail with "unauthorized_client". Retry with the built-in client (code_assist path forces it).
if err != nil && oauthType == "google_one" && strings.Contains(err.Error(), "unauthorized_client") && s.GetOAuthConfig().AIStudioOAuthEnabled {
if alt, altErr := s.RefreshToken(ctx, "code_assist", refreshToken, proxyURL); altErr == nil {
tokenInfo = alt
err = nil
}
}
if err != nil {
// Provide a more actionable error for common OAuth client mismatch issues.
if strings.Contains(err.Error(), "unauthorized_client") {
@@ -583,13 +787,14 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
case "code_assist":
// 先设置默认值或保留旧值,确保 tier_id 始终有值
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = "LEGACY" // 默认值
tokenInfo.TierID = canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
}
if tokenInfo.TierID == "" {
tokenInfo.TierID = GeminiTierGCPStandard
}
// 尝试自动探测 project_id 和 tier_id
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || tokenInfo.TierID == ""
if needDetect {
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
@@ -598,9 +803,10 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
tokenInfo.ProjectID = projectID
}
// 只有当原来没有 tier_id 且探测成功时才更新
if existingTierID == "" && tierID != "" {
tokenInfo.TierID = tierID
if tierID != "" {
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" {
tokenInfo.TierID = canonical
}
}
}
}
@@ -609,6 +815,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
}
case "google_one":
canonicalExistingTier := canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
// Check if tier cache is stale (> 24 hours)
needsRefresh := true
if account.Extra != nil {
@@ -617,30 +824,37 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
if time.Since(updatedAt) <= 24*time.Hour {
needsRefresh = false
// Use cached tier
if existingTierID != "" {
tokenInfo.TierID = existingTierID
}
tokenInfo.TierID = canonicalExistingTier
}
}
}
}
if tokenInfo.TierID == "" {
tokenInfo.TierID = canonicalExistingTier
}
if needsRefresh {
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
if err == nil && storageInfo != nil {
tokenInfo.TierID = tierID
tokenInfo.Extra = map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
if err == nil {
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" && canonical != GeminiTierGoogleOneUnknown {
tokenInfo.TierID = canonical
}
if storageInfo != nil {
tokenInfo.Extra = map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
}
}
}
}
if tokenInfo.TierID == "" || tokenInfo.TierID == GeminiTierGoogleOneUnknown {
if canonicalExistingTier != "" {
tokenInfo.TierID = canonicalExistingTier
} else {
// Fallback to cached or unknown
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = TierGoogleOneUnknown
}
tokenInfo.TierID = GeminiTierGoogleOneFree
}
}
}
@@ -669,6 +883,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
// Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID
fmt.Printf("[GeminiOAuth] Storing tier_id: %s\n", tokenInfo.TierID)
} else {
fmt.Printf("[GeminiOAuth] Invalid tier_id %s: %v\n", tokenInfo.TierID, err)
}
// Silently skip invalid tier_id (don't block account creation)
}
@@ -698,7 +915,13 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY"
if loadResp != nil {
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
// First try to get tier from currentTier/paidTier fields
if tier := loadResp.GetTier(); tier != "" {
tierID = tier
} else {
// Fallback to extracting from allowedTiers
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
}
}
// If LoadCodeAssist returned a project, use it

View File

@@ -1,50 +1,129 @@
package service
import "testing"
import (
"context"
"net/url"
"strings"
"testing"
func TestInferGoogleOneTier(t *testing.T) {
tests := []struct {
name string
storageBytes int64
expectedTier string
}{
{"Negative storage", -1, TierGoogleOneUnknown},
{"Zero storage", 0, TierGoogleOneUnknown},
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
)
// Free tier boundary (15GB)
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
{"Free tier (15GB)", StorageTierFree, TierFree},
func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
t.Parallel()
// Basic tier boundary (100GB)
{"Between free and basic", 50 * GB, TierFree},
{"Just below basic tier", StorageTierBasic - 1, TierFree},
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
type testCase struct {
name string
cfg *config.Config
oauthType string
projectID string
wantClientID string
wantRedirect string
wantScope string
wantProjectID string
wantErrSubstr string
}
// Standard tier boundary (200GB)
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
// AI Premium tier boundary (2TB)
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
// Unlimited tier boundary (> 100TB)
{"Between premium and unlimited", 50 * TB, TierAIPremium},
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
tests := []testCase{
{
name: "google_one uses built-in client when not configured and redirects to upstream",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{},
},
},
oauthType: "google_one",
wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "",
},
{
name: "google_one uses custom client when configured and redirects to localhost",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
},
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
wantScope: geminicli.DefaultGoogleOneScopes,
wantProjectID: "",
},
{
name: "code_assist always forces built-in client even when custom client configured",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
},
},
oauthType: "code_assist",
projectID: "my-gcp-project",
wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "my-gcp-project",
},
{
name: "ai_studio requires custom client",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{},
},
},
oauthType: "ai_studio",
wantErrSubstr: "AI Studio OAuth requires a custom OAuth Client",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
result := inferGoogleOneTier(tt.storageBytes)
if result != tt.expectedTier {
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
tt.storageBytes, result, tt.expectedTier)
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
if tt.wantErrSubstr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("expected error containing %q, got: %v", tt.wantErrSubstr, err)
}
return
}
if err != nil {
t.Fatalf("GenerateAuthURL returned error: %v", err)
}
parsed, err := url.Parse(got.AuthURL)
if err != nil {
t.Fatalf("failed to parse auth_url: %v", err)
}
q := parsed.Query()
if gotState := q.Get("state"); gotState != got.State {
t.Fatalf("state mismatch: query=%q result=%q", gotState, got.State)
}
if gotClientID := q.Get("client_id"); gotClientID != tt.wantClientID {
t.Fatalf("client_id mismatch: got=%q want=%q", gotClientID, tt.wantClientID)
}
if gotRedirect := q.Get("redirect_uri"); gotRedirect != tt.wantRedirect {
t.Fatalf("redirect_uri mismatch: got=%q want=%q", gotRedirect, tt.wantRedirect)
}
if gotScope := q.Get("scope"); gotScope != tt.wantScope {
t.Fatalf("scope mismatch: got=%q want=%q", gotScope, tt.wantScope)
}
if gotProjectID := q.Get("project_id"); gotProjectID != tt.wantProjectID {
t.Fatalf("project_id mismatch: got=%q want=%q", gotProjectID, tt.wantProjectID)
}
})
}

View File

@@ -20,13 +20,24 @@ const (
geminiModelFlash geminiModelClass = "flash"
)
type GeminiDailyQuota struct {
ProRPD int64
FlashRPD int64
type GeminiQuota struct {
// SharedRPD is a shared requests-per-day pool across models.
// When SharedRPD > 0, callers should treat ProRPD/FlashRPD as not applicable for daily quota checks.
SharedRPD int64 `json:"shared_rpd,omitempty"`
// SharedRPM is a shared requests-per-minute pool across models.
// When SharedRPM > 0, callers should treat ProRPM/FlashRPM as not applicable for minute quota checks.
SharedRPM int64 `json:"shared_rpm,omitempty"`
// Per-model quotas (AI Studio / API key).
// A value of -1 means "unlimited" (pay-as-you-go).
ProRPD int64 `json:"pro_rpd,omitempty"`
ProRPM int64 `json:"pro_rpm,omitempty"`
FlashRPD int64 `json:"flash_rpd,omitempty"`
FlashRPM int64 `json:"flash_rpm,omitempty"`
}
type GeminiTierPolicy struct {
Quota GeminiDailyQuota
Quota GeminiQuota
Cooldown time.Duration
}
@@ -45,10 +56,27 @@ type GeminiUsageTotals struct {
const geminiQuotaCacheTTL = time.Minute
type geminiQuotaOverrides struct {
type geminiQuotaOverridesV1 struct {
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
}
type geminiQuotaOverridesV2 struct {
QuotaRules map[string]geminiQuotaRuleOverride `json:"quota_rules"`
}
type geminiQuotaRuleOverride struct {
SharedRPD *int64 `json:"shared_rpd,omitempty"`
SharedRPM *int64 `json:"rpm,omitempty"`
GeminiPro *geminiModelQuotaOverride `json:"gemini_pro,omitempty"`
GeminiFlash *geminiModelQuotaOverride `json:"gemini_flash,omitempty"`
Desc *string `json:"desc,omitempty"`
}
type geminiModelQuotaOverride struct {
RPD *int64 `json:"rpd,omitempty"`
RPM *int64 `json:"rpm,omitempty"`
}
type GeminiQuotaService struct {
cfg *config.Config
settingRepo SettingRepository
@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if s.cfg != nil {
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
raw := []byte(s.cfg.Gemini.Quota.Policy)
var overridesV2 geminiQuotaOverridesV2
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
} else {
policy.ApplyOverrides(overrides.Tiers)
var overridesV1 geminiQuotaOverridesV1
if err := json.Unmarshal(raw, &overridesV1); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
} else {
policy.ApplyOverrides(overridesV1.Tiers)
}
}
}
}
@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if err != nil && !errors.Is(err, ErrSettingNotFound) {
log.Printf("gemini quota: load setting failed: %v", err)
} else if strings.TrimSpace(value) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
raw := []byte(value)
var overridesV2 geminiQuotaOverridesV2
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
} else {
policy.ApplyOverrides(overrides.Tiers)
var overridesV1 geminiQuotaOverridesV1
if err := json.Unmarshal(raw, &overridesV1); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
} else {
policy.ApplyOverrides(overridesV1.Tiers)
}
}
}
}
@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
return policy
}
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
if account == nil || !account.IsGeminiCodeAssist() {
return GeminiDailyQuota{}, false
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiQuota, bool) {
if account == nil || account.Platform != PlatformGemini {
return GeminiQuota{}, false
}
// Map (oauth_type + tier_id) to a canonical policy tier key.
// This keeps the policy table stable even if upstream tier_id strings vary.
tierKey := geminiQuotaTierKeyForAccount(account)
if tierKey == "" {
return GeminiQuota{}, false
}
policy := s.Policy(ctx)
return policy.QuotaForTier(account.GeminiTierID())
return policy.QuotaForTier(tierKey)
}
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string)
return policy.CooldownForTier(tierID)
}
func (s *GeminiQuotaService) CooldownForAccount(ctx context.Context, account *Account) time.Duration {
if s == nil || account == nil || account.Platform != PlatformGemini {
return 5 * time.Minute
}
tierKey := geminiQuotaTierKeyForAccount(account)
if strings.TrimSpace(tierKey) == "" {
return 5 * time.Minute
}
return s.CooldownForTier(ctx, tierKey)
}
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
return &GeminiQuotaPolicy{
tiers: map[string]GeminiTierPolicy{
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
// --- AI Studio / API Key (per-model) ---
// aistudio_free:
// - gemini_pro: 50 RPD / 2 RPM
// - gemini_flash: 1500 RPD / 15 RPM
GeminiTierAIStudioFree: {Quota: GeminiQuota{ProRPD: 50, ProRPM: 2, FlashRPD: 1500, FlashRPM: 15}, Cooldown: 30 * time.Minute},
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
GeminiTierAIStudioPaid: {Quota: GeminiQuota{ProRPD: -1, ProRPM: 1000, FlashRPD: -1, FlashRPM: 2000}, Cooldown: 5 * time.Minute},
// --- Google One (shared pool) ---
GeminiTierGoogleOneFree: {Quota: GeminiQuota{SharedRPD: 1000, SharedRPM: 60}, Cooldown: 30 * time.Minute},
GeminiTierGoogleAIPro: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
GeminiTierGoogleAIUltra: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
// --- GCP Code Assist (shared pool) ---
GeminiTierGCPStandard: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
GeminiTierGCPEnterprise: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
},
}
}
@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
// Backward-compatible overrides:
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
// - Otherwise apply per-model overrides.
if override.ProRPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
if policy.Quota.SharedRPD > 0 {
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
} else {
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
}
}
if override.FlashRPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
if policy.Quota.SharedRPD > 0 {
// No separate flash RPD for shared tiers.
} else {
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.FlashRPD)
}
}
if override.CooldownMinutes != nil {
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
func (p *GeminiQuotaPolicy) ApplyQuotaRulesOverrides(rules map[string]geminiQuotaRuleOverride) {
if p == nil || len(rules) == 0 {
return
}
for rawID, override := range rules {
tierID := normalizeGeminiTierID(rawID)
if tierID == "" {
continue
}
policy, ok := p.tiers[tierID]
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
if override.SharedRPD != nil {
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.SharedRPD)
}
if override.SharedRPM != nil {
policy.Quota.SharedRPM = clampGeminiQuotaRPM(*override.SharedRPM)
}
if override.GeminiPro != nil {
if override.GeminiPro.RPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiPro.RPD)
}
if override.GeminiPro.RPM != nil {
policy.Quota.ProRPM = clampGeminiQuotaRPM(*override.GeminiPro.RPM)
}
}
if override.GeminiFlash != nil {
if override.GeminiFlash.RPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiFlash.RPD)
}
if override.GeminiFlash.RPM != nil {
policy.Quota.FlashRPM = clampGeminiQuotaRPM(*override.GeminiFlash.RPM)
}
}
p.tiers[tierID] = policy
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiQuota, bool) {
policy, ok := p.policyForTier(tierID)
if !ok {
return GeminiDailyQuota{}, false
return GeminiQuota{}, false
}
return policy.Quota, true
}
@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool
return GeminiTierPolicy{}, false
}
normalized := normalizeGeminiTierID(tierID)
if normalized == "" {
normalized = "LEGACY"
}
if policy, ok := p.tiers[normalized]; ok {
return policy, true
}
policy, ok := p.tiers["LEGACY"]
return policy, ok
return GeminiTierPolicy{}, false
}
func normalizeGeminiTierID(tierID string) string {
return strings.ToUpper(strings.TrimSpace(tierID))
tierID = strings.TrimSpace(tierID)
if tierID == "" {
return ""
}
// Prefer canonical mapping (handles legacy tier strings).
if canonical := canonicalGeminiTierID(tierID); canonical != "" {
return canonical
}
// Accept older policy keys that used uppercase names.
switch strings.ToUpper(tierID) {
case "AISTUDIO_FREE":
return GeminiTierAIStudioFree
case "AISTUDIO_PAID":
return GeminiTierAIStudioPaid
case "GOOGLE_ONE_FREE":
return GeminiTierGoogleOneFree
case "GOOGLE_AI_PRO":
return GeminiTierGoogleAIPro
case "GOOGLE_AI_ULTRA":
return GeminiTierGoogleAIUltra
case "GCP_STANDARD":
return GeminiTierGCPStandard
case "GCP_ENTERPRISE":
return GeminiTierGCPEnterprise
}
return strings.ToLower(tierID)
}
func clampGeminiQuotaInt64(value int64) int64 {
if value < 0 {
func clampGeminiQuotaInt64WithUnlimited(value int64) int64 {
if value < -1 {
return 0
}
return value
@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int {
return value
}
func clampGeminiQuotaRPM(value int64) int64 {
if value < 0 {
return 0
}
return value
}
func geminiCooldownForTier(tierID string) time.Duration {
policy := newGeminiQuotaPolicy()
return policy.CooldownForTier(tierID)
}
func geminiQuotaTierKeyForAccount(account *Account) string {
if account == nil || account.Platform != PlatformGemini {
return ""
}
// Note: GeminiOAuthType() already defaults legacy (project_id present) to code_assist.
oauthType := strings.ToLower(strings.TrimSpace(account.GeminiOAuthType()))
rawTier := strings.TrimSpace(account.GeminiTierID())
// Prefer the canonical tier stored in credentials.
if tierID := canonicalGeminiTierIDForOAuthType(oauthType, rawTier); tierID != "" && tierID != GeminiTierGoogleOneUnknown {
return tierID
}
// Fallback defaults when tier_id is missing or unknown.
switch oauthType {
case "google_one":
return GeminiTierGoogleOneFree
case "code_assist":
return GeminiTierGCPStandard
case "ai_studio":
return GeminiTierAIStudioFree
default:
// API Key accounts (type=apikey) have empty oauth_type and are treated as AI Studio.
return GeminiTierAIStudioFree
}
}
func geminiModelClassFromName(model string) geminiModelClass {
name := strings.ToLower(strings.TrimSpace(model))
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {

View File

@@ -487,7 +487,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case AccountTypeApiKey:
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
}
// Handle upstream error (mark account status)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// Return appropriate error response
var errType, errMsg string
@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1020,22 +1026,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
_ = s.usageLogRepo.Create(ctx, usageLog)
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
shouldBill := inserted || err != nil
// Deduct based on billing type
if isSubscriptionBilling {
if cost.TotalCost > 0 {
if shouldBill && cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if shouldBill && cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}

Some files were not shown because too many files have changed in this diff Show More