mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-13 11:24:46 +08:00
Merge PR #142: feat: 多平台网关优化与用户系统增强
主要功能: - Gemini OAuth 配额系统优化:Google One tier 自动推断 - Antigravity 网关增强:Thinking Block 重试、Claude 模型 signature 透传 - 账号调度改进:临时不可调度功能、负载感知调度优化 - 前端用户体验:账号管理界面优化、使用教程改进 冲突解决:保留 handleUpstreamError 的 prefix 参数和日志记录能力, 同时合并 PR 的 thinking block 重试和 model fallback 功能。
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
// Package config provides configuration loading, defaults, and validation.
|
||||
package config
|
||||
|
||||
import (
|
||||
@@ -139,7 +140,7 @@ type GatewayConfig struct {
|
||||
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
|
||||
|
||||
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
|
||||
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
|
||||
InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"`
|
||||
|
||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||
@@ -241,7 +242,7 @@ type DefaultConfig struct {
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
UserConcurrency int `mapstructure:"user_concurrency"`
|
||||
UserBalance float64 `mapstructure:"user_balance"`
|
||||
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
|
||||
APIKeyPrefix string `mapstructure:"api_key_prefix"`
|
||||
RateMultiplier float64 `mapstructure:"rate_multiplier"`
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
// Package admin provides HTTP handlers for administrative operations.
|
||||
package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
@@ -69,42 +72,45 @@ func NewAccountHandler(
|
||||
|
||||
// CreateAccountRequest represents create account request
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
// UpdateAccountRequest represents update account request
|
||||
// 使用指针类型来区分"未提供"和"设置为0"
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
|
||||
type BulkUpdateAccountsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||
@@ -179,18 +185,40 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Name: req.Name,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
// 检查是否为混合渠道错误
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
// 返回特殊错误码要求确认
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
"require_confirmation": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
@@ -213,18 +241,40 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
// 检查是否为混合渠道错误
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
// 返回特殊错误码要求确认
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
"require_confirmation": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
@@ -568,6 +618,9 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
|
||||
hasUpdates := req.Name != "" ||
|
||||
req.ProxyID != nil ||
|
||||
req.Concurrency != nil ||
|
||||
@@ -583,15 +636,16 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
}
|
||||
|
||||
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
|
||||
AccountIDs: req.AccountIDs,
|
||||
Name: req.Name,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
AccountIDs: req.AccountIDs,
|
||||
Name: req.Name,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -781,6 +835,49 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
// GET /api/v1/admin/accounts/:id/temp-unschedulable
|
||||
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
state, err := h.rateLimitService.GetTempUnschedStatus(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if state == nil || state.UntilUnix <= time.Now().Unix() {
|
||||
response.Success(c, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"active": true,
|
||||
"state": state,
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTempUnschedulable handles clearing temporary unschedulable status
|
||||
// DELETE /api/v1/admin/accounts/:id/temp-unschedulable
|
||||
func (h *AccountHandler) ClearTempUnschedulable(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.rateLimitService.ClearTempUnschedulable(c.Request.Context(), accountID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Temp unschedulable cleared successfully"})
|
||||
}
|
||||
|
||||
// GetTodayStats handles getting account today statistics
|
||||
// GET /api/v1/admin/accounts/:id/today-stats
|
||||
func (h *AccountHandler) GetTodayStats(c *gin.Context) {
|
||||
|
||||
@@ -75,8 +75,8 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
"active_users": stats.ActiveUsers,
|
||||
|
||||
// API Key 统计
|
||||
"total_api_keys": stats.TotalApiKeys,
|
||||
"active_api_keys": stats.ActiveApiKeys,
|
||||
"total_api_keys": stats.TotalAPIKeys,
|
||||
"active_api_keys": stats.ActiveAPIKeys,
|
||||
|
||||
// 账户统计
|
||||
"total_accounts": stats.TotalAccounts,
|
||||
@@ -193,10 +193,10 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetApiKeyUsageTrend handles getting API key usage trend data
|
||||
// GetAPIKeyUsageTrend handles getting API key usage trend data
|
||||
// GET /api/v1/admin/dashboard/api-keys-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
|
||||
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
limitStr := c.DefaultQuery("limit", "5")
|
||||
@@ -205,7 +205,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
@@ -273,26 +273,26 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
|
||||
type BatchAPIKeysUsageRequest struct {
|
||||
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
|
||||
// GetBatchAPIKeysUsage handles getting usage stats for multiple API keys
|
||||
// POST /api/v1/admin/dashboard/api-keys-usage
|
||||
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
var req BatchApiKeysUsageRequest
|
||||
func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
var req BatchAPIKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
if len(req.APIKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
|
||||
@@ -18,6 +18,7 @@ func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *Gemi
|
||||
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
|
||||
}
|
||||
|
||||
// GetCapabilities returns the Gemini OAuth configuration capabilities.
|
||||
// GET /api/v1/admin/gemini/oauth/capabilities
|
||||
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
|
||||
cfg := h.geminiOAuthService.GetOAuthConfig()
|
||||
@@ -30,6 +31,8 @@ type GeminiGenerateAuthURLRequest struct {
|
||||
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
|
||||
// 默认为 "code_assist" 以保持向后兼容
|
||||
OAuthType string `json:"oauth_type"`
|
||||
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
|
||||
TierID string `json:"tier_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
|
||||
@@ -54,7 +57,7 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
// Always pass the "hosted" callback URI; the OAuth service may override it depending on
|
||||
// oauth_type and whether the built-in Gemini CLI OAuth client is used.
|
||||
redirectURI := deriveGeminiRedirectURI(c)
|
||||
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType)
|
||||
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType, req.TierID)
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
// Treat missing/invalid OAuth client configuration as a user/config error.
|
||||
@@ -76,6 +79,9 @@ type GeminiExchangeCodeRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
|
||||
OAuthType string `json:"oauth_type"`
|
||||
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
|
||||
// This field is optional; when omitted, the server uses the tier stored in the OAuth session.
|
||||
TierID string `json:"tier_id"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens.
|
||||
@@ -103,6 +109,7 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
Code: req.Code,
|
||||
ProxyID: req.ProxyID,
|
||||
OAuthType: oauthType,
|
||||
TierID: req.TierID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
|
||||
@@ -237,9 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outKeys := make([]dto.ApiKey, 0, len(keys))
|
||||
outKeys := make([]dto.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
|
||||
outKeys = append(outKeys, *dto.APIKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -34,26 +34,31 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SmtpHost: settings.SmtpHost,
|
||||
SmtpPort: settings.SmtpPort,
|
||||
SmtpUsername: settings.SmtpUsername,
|
||||
SmtpPassword: settings.SmtpPassword,
|
||||
SmtpFrom: settings.SmtpFrom,
|
||||
SmtpFromName: settings.SmtpFromName,
|
||||
SmtpUseTLS: settings.SmtpUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: settings.TurnstileSecretKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPassword: settings.SMTPPassword,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: settings.TurnstileSecretKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -64,13 +69,13 @@ type UpdateSettingsRequest struct {
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
// 邮件服务设置
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
SMTPFrom string `json:"smtp_from_email"`
|
||||
SMTPFromName string `json:"smtp_from_name"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
@@ -81,13 +86,20 @@ type UpdateSettingsRequest struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
DocURL string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
|
||||
FallbackModelOpenAI string `json:"fallback_model_openai"`
|
||||
FallbackModelGemini string `json:"fallback_model_gemini"`
|
||||
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -106,8 +118,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if req.DefaultBalance < 0 {
|
||||
req.DefaultBalance = 0
|
||||
}
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
|
||||
// Turnstile 参数验证
|
||||
@@ -141,26 +153,31 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SmtpHost: req.SmtpHost,
|
||||
SmtpPort: req.SmtpPort,
|
||||
SmtpUsername: req.SmtpUsername,
|
||||
SmtpPassword: req.SmtpPassword,
|
||||
SmtpFrom: req.SmtpFrom,
|
||||
SmtpFromName: req.SmtpFromName,
|
||||
SmtpUseTLS: req.SmtpUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
ApiBaseUrl: req.ApiBaseUrl,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocUrl: req.DocUrl,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
@@ -176,69 +193,74 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SmtpHost: updatedSettings.SmtpHost,
|
||||
SmtpPort: updatedSettings.SmtpPort,
|
||||
SmtpUsername: updatedSettings.SmtpUsername,
|
||||
SmtpPassword: updatedSettings.SmtpPassword,
|
||||
SmtpFrom: updatedSettings.SmtpFrom,
|
||||
SmtpFromName: updatedSettings.SmtpFromName,
|
||||
SmtpUseTLS: updatedSettings.SmtpUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
ApiBaseUrl: updatedSettings.ApiBaseUrl,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocUrl: updatedSettings.DocUrl,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPassword: updatedSettings.SMTPPassword,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
})
|
||||
}
|
||||
|
||||
// TestSmtpRequest 测试SMTP连接请求
|
||||
type TestSmtpRequest struct {
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
// TestSMTPRequest 测试SMTP连接请求
|
||||
type TestSMTPRequest struct {
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// TestSmtpConnection 测试SMTP连接
|
||||
// TestSMTPConnection 测试SMTP连接
|
||||
// POST /api/v1/admin/settings/test-smtp
|
||||
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
var req TestSmtpRequest
|
||||
func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
var req TestSMTPRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
password := req.SMTPPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
Port: req.SMTPPort,
|
||||
Username: req.SMTPUsername,
|
||||
Password: password,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
UseTLS: req.SMTPUseTLS,
|
||||
}
|
||||
|
||||
err := h.emailService.TestSmtpConnectionWithConfig(config)
|
||||
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -250,13 +272,13 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
// SendTestEmailRequest 发送测试邮件请求
|
||||
type SendTestEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
SMTPFrom string `json:"smtp_from_email"`
|
||||
SMTPFromName string `json:"smtp_from_name"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// SendTestEmail 发送测试邮件
|
||||
@@ -268,27 +290,27 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
password := req.SMTPPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
Port: req.SMTPPort,
|
||||
Username: req.SMTPUsername,
|
||||
Password: password,
|
||||
From: req.SmtpFrom,
|
||||
FromName: req.SmtpFromName,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
From: req.SMTPFrom,
|
||||
FromName: req.SMTPFromName,
|
||||
UseTLS: req.SMTPUseTLS,
|
||||
}
|
||||
|
||||
siteName := h.settingService.GetSiteName(c.Request.Context())
|
||||
@@ -333,10 +355,10 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取管理员 API Key 状态
|
||||
// GetAdminAPIKey 获取管理员 API Key 状态
|
||||
// GET /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||
func (h *SettingHandler) GetAdminAPIKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminAPIKeyStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -348,10 +370,10 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
|
||||
// RegenerateAdminAPIKey 生成/重新生成管理员 API Key
|
||||
// POST /api/v1/admin/settings/admin-api-key/regenerate
|
||||
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||
func (h *SettingHandler) RegenerateAdminAPIKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminAPIKey(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -362,10 +384,10 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
// DeleteAdminAPIKey 删除管理员 API Key
|
||||
// DELETE /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||
func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminAPIKey(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -17,14 +17,14 @@ import (
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
apiKeyService *service.APIKeyService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
func NewUsageHandler(
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
adminService service.AdminService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
@@ -125,7 +125,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
@@ -207,7 +207,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -269,9 +269,9 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// SearchApiKeys handles searching API keys by user
|
||||
// SearchAPIKeys handles searching API keys by user
|
||||
// GET /api/v1/admin/usage/search-api-keys
|
||||
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
|
||||
userIDStr := c.Query("user_id")
|
||||
keyword := c.Query("q")
|
||||
|
||||
@@ -285,22 +285,22 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
userID = id
|
||||
}
|
||||
|
||||
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, 30)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return simplified API key list (only id and name)
|
||||
type SimpleApiKey struct {
|
||||
type SimpleAPIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
result := make([]SimpleApiKey, len(keys))
|
||||
result := make([]SimpleAPIKey, len(keys))
|
||||
for i, k := range keys {
|
||||
result[i] = SimpleApiKey{
|
||||
result[i] = SimpleAPIKey{
|
||||
ID: k.ID,
|
||||
Name: k.Name,
|
||||
UserID: k.UserID,
|
||||
|
||||
@@ -243,9 +243,9 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
out := make([]dto.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
out = append(out, *dto.APIKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package handler provides HTTP request handlers for the application.
|
||||
package handler
|
||||
|
||||
import (
|
||||
@@ -14,11 +15,11 @@ import (
|
||||
|
||||
// APIKeyHandler handles API key-related requests
|
||||
type APIKeyHandler struct {
|
||||
apiKeyService *service.ApiKeyService
|
||||
apiKeyService *service.APIKeyService
|
||||
}
|
||||
|
||||
// NewAPIKeyHandler creates a new APIKeyHandler
|
||||
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
|
||||
func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
|
||||
return &APIKeyHandler{
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
@@ -56,9 +57,9 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
out := make([]dto.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
out = append(out, *dto.APIKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
@@ -90,7 +91,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
}
|
||||
|
||||
// Create handles creating a new API key
|
||||
@@ -108,7 +109,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.CreateApiKeyRequest{
|
||||
svcReq := service.CreateAPIKeyRequest{
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
@@ -119,7 +120,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
@@ -143,7 +144,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateApiKeyRequest{}
|
||||
svcReq := service.UpdateAPIKeyRequest{}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
}
|
||||
@@ -158,7 +159,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
}
|
||||
|
||||
// Delete handles deleting an API key
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package dto provides data transfer objects for HTTP handlers.
|
||||
package dto
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -26,11 +27,11 @@ func UserFromService(u *service.User) *User {
|
||||
return nil
|
||||
}
|
||||
out := UserFromServiceShallow(u)
|
||||
if len(u.ApiKeys) > 0 {
|
||||
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
|
||||
for i := range u.ApiKeys {
|
||||
k := u.ApiKeys[i]
|
||||
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
|
||||
if len(u.APIKeys) > 0 {
|
||||
out.APIKeys = make([]APIKey, 0, len(u.APIKeys))
|
||||
for i := range u.APIKeys {
|
||||
k := u.APIKeys[i]
|
||||
out.APIKeys = append(out.APIKeys, *APIKeyFromService(&k))
|
||||
}
|
||||
}
|
||||
if len(u.Subscriptions) > 0 {
|
||||
@@ -43,11 +44,11 @@ func UserFromService(u *service.User) *User {
|
||||
return out
|
||||
}
|
||||
|
||||
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
|
||||
func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &ApiKey{
|
||||
return &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
@@ -103,28 +104,30 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
return nil
|
||||
}
|
||||
return &Account{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Platform: a.Platform,
|
||||
Type: a.Type,
|
||||
Credentials: a.Credentials,
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
Priority: a.Priority,
|
||||
Status: a.Status,
|
||||
ErrorMessage: a.ErrorMessage,
|
||||
LastUsedAt: a.LastUsedAt,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
Schedulable: a.Schedulable,
|
||||
RateLimitedAt: a.RateLimitedAt,
|
||||
RateLimitResetAt: a.RateLimitResetAt,
|
||||
OverloadUntil: a.OverloadUntil,
|
||||
SessionWindowStart: a.SessionWindowStart,
|
||||
SessionWindowEnd: a.SessionWindowEnd,
|
||||
SessionWindowStatus: a.SessionWindowStatus,
|
||||
GroupIDs: a.GroupIDs,
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Platform: a.Platform,
|
||||
Type: a.Type,
|
||||
Credentials: a.Credentials,
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
Priority: a.Priority,
|
||||
Status: a.Status,
|
||||
ErrorMessage: a.ErrorMessage,
|
||||
LastUsedAt: a.LastUsedAt,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
Schedulable: a.Schedulable,
|
||||
RateLimitedAt: a.RateLimitedAt,
|
||||
RateLimitResetAt: a.RateLimitResetAt,
|
||||
OverloadUntil: a.OverloadUntil,
|
||||
TempUnschedulableUntil: a.TempUnschedulableUntil,
|
||||
TempUnschedulableReason: a.TempUnschedulableReason,
|
||||
SessionWindowStart: a.SessionWindowStart,
|
||||
SessionWindowEnd: a.SessionWindowEnd,
|
||||
SessionWindowStatus: a.SessionWindowStatus,
|
||||
GroupIDs: a.GroupIDs,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,7 +223,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
return &UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
ApiKeyID: l.ApiKeyID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
@@ -245,7 +248,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
ApiKey: ApiKeyFromService(l.ApiKey),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
Account: AccountFromService(l.Account),
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
|
||||
@@ -5,13 +5,13 @@ type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password,omitempty"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password,omitempty"`
|
||||
SMTPFrom string `json:"smtp_from_email"`
|
||||
SMTPFromName string `json:"smtp_from_name"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
@@ -20,12 +20,19 @@ type SystemSettings struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
DocURL string `json:"doc_url"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
|
||||
FallbackModelOpenAI string `json:"fallback_model_openai"`
|
||||
FallbackModelGemini string `json:"fallback_model_gemini"`
|
||||
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
@@ -36,8 +43,8 @@ type PublicSettings struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
DocURL string `json:"doc_url"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ type User struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
ApiKeys []ApiKey `json:"api_keys,omitempty"`
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
type ApiKey struct {
|
||||
type APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
@@ -76,6 +76,9 @@ type Account struct {
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
||||
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
||||
|
||||
SessionWindowStart *time.Time `json:"session_window_start"`
|
||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||
SessionWindowStatus string `json:"session_window_status"`
|
||||
@@ -136,7 +139,7 @@ type RedeemCode struct {
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
@@ -168,7 +171,7 @@ type UsageLog struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
ApiKey *ApiKey `json:"api_key,omitempty"`
|
||||
APIKey *APIKey `json:"api_key,omitempty"`
|
||||
Account *Account `json:"account,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
|
||||
@@ -53,7 +53,7 @@ func NewGatewayHandler(
|
||||
// POST /v1/messages
|
||||
func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -259,7 +259,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
@@ -383,7 +383,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
@@ -400,7 +400,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// Returns models based on account configurations (model_mapping whitelist)
|
||||
// Falls back to default models if no whitelist is configured
|
||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
apiKey, _ := middleware2.GetApiKeyFromContext(c)
|
||||
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
|
||||
|
||||
var groupID *int64
|
||||
var platform string
|
||||
@@ -458,7 +458,7 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
|
||||
// Usage handles getting account balance for CC Switch integration
|
||||
// GET /v1/usage
|
||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -628,7 +628,7 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
|
||||
// 特点:校验订阅/余额,但不计算并发、不记录使用量
|
||||
func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
@@ -67,7 +67,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
// GeminiV1BetaGetModel proxies:
|
||||
// GET /v1beta/models/{model}
|
||||
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
@@ -120,7 +120,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
// POST /v1beta/models/{model}:generateContent
|
||||
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
|
||||
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
@@ -299,7 +299,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
|
||||
@@ -41,7 +41,7 @@ func NewOpenAIGatewayHandler(
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -235,7 +235,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
|
||||
@@ -39,9 +39,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
DocURL: settings.DocURL,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -18,11 +18,11 @@ import (
|
||||
// UsageHandler handles usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
apiKeyService *service.APIKeyService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new UsageHandler
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.APIKeyService) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
@@ -111,7 +111,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: subject.UserID, // Always filter by current user for security
|
||||
ApiKeyID: apiKeyID,
|
||||
APIKeyID: apiKeyID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
@@ -235,7 +235,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
var stats *service.UsageStats
|
||||
var err error
|
||||
if apiKeyID > 0 {
|
||||
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
stats, err = h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
} else {
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
|
||||
}
|
||||
@@ -346,49 +346,49 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request for batch API keys usage
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
// BatchAPIKeysUsageRequest represents the request for batch API keys usage
|
||||
type BatchAPIKeysUsageRequest struct {
|
||||
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
|
||||
// DashboardAPIKeysUsage handles getting usage stats for user's own API keys
|
||||
// POST /api/v1/usage/dashboard/api-keys-usage
|
||||
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchApiKeysUsageRequest
|
||||
var req BatchAPIKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
if len(req.APIKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
// Limit the number of API key IDs to prevent SQL parameter overflow
|
||||
if len(req.ApiKeyIDs) > 100 {
|
||||
if len(req.APIKeyIDs) > 100 {
|
||||
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
|
||||
return
|
||||
}
|
||||
|
||||
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
|
||||
validAPIKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.APIKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
if len(validAPIKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package antigravity provides a client for the Antigravity API.
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
@@ -57,6 +58,29 @@ type TierInfo struct {
|
||||
Description string `json:"description"` // 描述
|
||||
}
|
||||
|
||||
// UnmarshalJSON supports both legacy string tiers and object tiers.
|
||||
func (t *TierInfo) UnmarshalJSON(data []byte) error {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
var id string
|
||||
if err := json.Unmarshal(data, &id); err != nil {
|
||||
return err
|
||||
}
|
||||
t.ID = id
|
||||
return nil
|
||||
}
|
||||
type alias TierInfo
|
||||
var decoded alias
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
return err
|
||||
}
|
||||
*t = TierInfo(decoded)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IneligibleTier 不符合条件的层级信息
|
||||
type IneligibleTier struct {
|
||||
Tier *TierInfo `json:"tier,omitempty"`
|
||||
|
||||
@@ -240,10 +240,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
ID: block.ID,
|
||||
},
|
||||
}
|
||||
// 只有 Gemini 模型使用 dummy signature
|
||||
// Claude 模型不设置 signature(避免验证问题)
|
||||
// tool_use 的 signature 处理:
|
||||
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验)
|
||||
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路)
|
||||
if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
|
||||
part.ThoughtSignature = block.Signature
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
|
||||
@@ -15,26 +15,26 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Claude model - skip thinking block without signature",
|
||||
name: "Claude model - drop thinking without signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 2, // 只有两个text block
|
||||
description: "Claude模型应该跳过无signature的thinking block",
|
||||
expectedParts: 2, // thinking 内容被丢弃
|
||||
description: "Claude模型应丢弃无signature的thinking block内容",
|
||||
},
|
||||
{
|
||||
name: "Claude model - keep thinking block with signature",
|
||||
name: "Claude model - preserve thinking block with signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "sig_real_123"},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 3, // 三个block都保留
|
||||
description: "Claude模型应该保留有signature的thinking block",
|
||||
expectedParts: 3,
|
||||
description: "Claude模型应透传带 signature 的 thinking block(用于 Vertex 签名链路)",
|
||||
},
|
||||
{
|
||||
name: "Gemini model - use dummy signature",
|
||||
@@ -61,10 +61,64 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
if len(parts) != tt.expectedParts {
|
||||
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
|
||||
}
|
||||
|
||||
switch tt.name {
|
||||
case "Claude model - preserve thinking block with signature":
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
if !parts[1].Thought || parts[1].ThoughtSignature != "sig_real_123" {
|
||||
t.Fatalf("expected thought part with signature sig_real_123, got thought=%v signature=%q",
|
||||
parts[1].Thought, parts[1].ThoughtSignature)
|
||||
}
|
||||
case "Gemini model - use dummy signature":
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
if !parts[1].Thought || parts[1].ThoughtSignature != dummyThoughtSignature {
|
||||
t.Fatalf("expected dummy thought signature, got thought=%v signature=%q",
|
||||
parts[1].Thought, parts[1].ThoughtSignature)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
content := `[
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
|
||||
]`
|
||||
|
||||
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, err := buildParts(json.RawMessage(content), toolIDToName, true)
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != dummyThoughtSignature {
|
||||
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Claude model - preserve valid signature for tool_use", func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, err := buildParts(json.RawMessage(content), toolIDToName, false)
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
// Claude 模型应透传有效的 signature(Vertex/Google 需要完整签名链路)
|
||||
if parts[0].ThoughtSignature != "sig_tool_abc" {
|
||||
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
|
||||
func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package claude provides constants and helpers for Claude API integration.
|
||||
package claude
|
||||
|
||||
// Claude Code 客户端相关常量
|
||||
@@ -16,13 +17,13 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
|
||||
// Claude Code 客户端默认请求头
|
||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||
var DefaultHeaders = map[string]string{
|
||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package errors provides application error types and helpers.
|
||||
// nolint:mnd
|
||||
package errors
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package gemini
|
||||
|
||||
// This package provides minimal fallback model metadata for Gemini native endpoints.
|
||||
// Package gemini provides minimal fallback model metadata for Gemini native endpoints.
|
||||
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||
package gemini
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
|
||||
type LoadCodeAssistRequest struct {
|
||||
Metadata LoadCodeAssistMetadata `json:"metadata"`
|
||||
@@ -11,12 +16,51 @@ type LoadCodeAssistMetadata struct {
|
||||
PluginType string `json:"pluginType"`
|
||||
}
|
||||
|
||||
type TierInfo struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON supports both legacy string tiers and object tiers.
|
||||
func (t *TierInfo) UnmarshalJSON(data []byte) error {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
var id string
|
||||
if err := json.Unmarshal(data, &id); err != nil {
|
||||
return err
|
||||
}
|
||||
t.ID = id
|
||||
return nil
|
||||
}
|
||||
type alias TierInfo
|
||||
var decoded alias
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
return err
|
||||
}
|
||||
*t = TierInfo(decoded)
|
||||
return nil
|
||||
}
|
||||
|
||||
type LoadCodeAssistResponse struct {
|
||||
CurrentTier string `json:"currentTier,omitempty"`
|
||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
|
||||
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
|
||||
}
|
||||
|
||||
// GetTier extracts tier ID, prioritizing paidTier over currentTier
|
||||
func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
if r.PaidTier != nil && r.PaidTier.ID != "" {
|
||||
return r.PaidTier.ID
|
||||
}
|
||||
if r.CurrentTier != nil {
|
||||
return r.CurrentTier.ID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type AllowedTier struct {
|
||||
ID string `json:"id"`
|
||||
IsDefault bool `json:"isDefault,omitempty"`
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package geminicli provides helpers for interacting with Gemini CLI tools.
|
||||
package geminicli
|
||||
|
||||
import "time"
|
||||
@@ -26,6 +27,12 @@ const (
|
||||
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
|
||||
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
|
||||
|
||||
// DefaultScopes for Google One (personal Google accounts with Gemini access)
|
||||
// Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client,
|
||||
// Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client
|
||||
// cannot request restricted scopes like generative-language.retriever or drive.readonly.
|
||||
DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
|
||||
|
||||
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
|
||||
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
|
||||
|
||||
|
||||
@@ -11,11 +11,12 @@ type Model struct {
|
||||
|
||||
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
const DefaultTestModel = "gemini-3-pro-preview"
|
||||
const DefaultTestModel = "gemini-2.0-flash"
|
||||
|
||||
@@ -19,13 +19,17 @@ type OAuthConfig struct {
|
||||
}
|
||||
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
// TierID is a user-selected fallback tier.
|
||||
// For oauth types that support auto detection (google_one/code_assist), the server will prefer
|
||||
// the detected tier and fall back to TierID when detection fails.
|
||||
TierID string `json:"tier_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type SessionStore struct {
|
||||
@@ -172,23 +176,32 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
|
||||
if effective.Scopes == "" {
|
||||
// Use different default scopes based on OAuth type
|
||||
if oauthType == "ai_studio" {
|
||||
switch oauthType {
|
||||
case "ai_studio":
|
||||
// Built-in client can't request some AI Studio scopes (notably generative-language).
|
||||
if isBuiltinClient {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = DefaultAIStudioScopes
|
||||
}
|
||||
} else {
|
||||
case "google_one":
|
||||
// Google One uses built-in Gemini CLI client (same as code_assist)
|
||||
// Built-in client can't request restricted scopes like generative-language.retriever
|
||||
if isBuiltinClient {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = DefaultGoogleOneScopes
|
||||
}
|
||||
default:
|
||||
// Default to Code Assist scopes
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
}
|
||||
} else if oauthType == "ai_studio" && isBuiltinClient {
|
||||
} else if (oauthType == "ai_studio" || oauthType == "google_one") && isBuiltinClient {
|
||||
// If user overrides scopes while still using the built-in client, strip restricted scopes.
|
||||
parts := strings.Fields(effective.Scopes)
|
||||
filtered := make([]string, 0, len(parts))
|
||||
for _, s := range parts {
|
||||
if strings.Contains(s, "generative-language") {
|
||||
if hasRestrictedScope(s) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, s)
|
||||
@@ -214,6 +227,11 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
return effective, nil
|
||||
}
|
||||
|
||||
func hasRestrictedScope(scope string) bool {
|
||||
return strings.HasPrefix(scope, "https://www.googleapis.com/auth/generative-language") ||
|
||||
strings.HasPrefix(scope, "https://www.googleapis.com/auth/drive")
|
||||
}
|
||||
|
||||
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
|
||||
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
|
||||
if err != nil {
|
||||
|
||||
113
backend/internal/pkg/geminicli/oauth_test.go
Normal file
113
backend/internal/pkg/geminicli/oauth_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input OAuthConfig
|
||||
oauthType string
|
||||
wantClientID string
|
||||
wantScopes string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Google One with built-in client (empty config)",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "google_one",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
wantScopes: DefaultCodeAssistScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with custom client",
|
||||
input: OAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantScopes: DefaultGoogleOneScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and custom scopes (should filter restricted scopes)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
wantScopes: "https://www.googleapis.com/auth/cloud-platform",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and only restricted scopes (should fallback to default)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
wantScopes: DefaultCodeAssistScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Code Assist with built-in client",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "code_assist",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
wantScopes: DefaultCodeAssistScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := EffectiveOAuthConfig(tt.input, tt.oauthType)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("EffectiveOAuthConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if got.ClientID != tt.wantClientID {
|
||||
t.Errorf("EffectiveOAuthConfig() ClientID = %v, want %v", got.ClientID, tt.wantClientID)
|
||||
}
|
||||
if got.Scopes != tt.wantScopes {
|
||||
t.Errorf("EffectiveOAuthConfig() Scopes = %v, want %v", got.Scopes, tt.wantScopes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
|
||||
// Test that Google One with built-in client filters out restricted scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile",
|
||||
}, "google_one")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
|
||||
// Should only contain cloud-platform, userinfo.email, and userinfo.profile
|
||||
// Should NOT contain generative-language or drive scopes
|
||||
if strings.Contains(cfg.Scopes, "generative-language") {
|
||||
t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes)
|
||||
}
|
||||
if strings.Contains(cfg.Scopes, "drive") {
|
||||
t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.email") {
|
||||
t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.profile") {
|
||||
t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package googleapi provides helpers for Google-style API responses.
|
||||
package googleapi
|
||||
|
||||
import "net/http"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package oauth provides helpers for OAuth flows used by this service.
|
||||
package oauth
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package openai provides helpers and types for OpenAI API integration.
|
||||
package openai
|
||||
|
||||
import _ "embed"
|
||||
|
||||
@@ -327,7 +327,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ExtractUserInfo extracts user information from ID Token claims
|
||||
// UserInfo represents user information extracted from ID Token claims.
|
||||
type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package pagination provides types and helpers for paginated responses.
|
||||
package pagination
|
||||
|
||||
// PaginationParams 分页参数
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package response provides standardized HTTP response helpers.
|
||||
package response
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package sysutil provides system-level utilities for process management.
|
||||
package sysutil
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package usagestats provides types for usage statistics and reporting.
|
||||
package usagestats
|
||||
|
||||
import "time"
|
||||
@@ -10,8 +11,8 @@ type DashboardStats struct {
|
||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
TotalAPIKeys int64 `json:"total_api_keys"`
|
||||
ActiveAPIKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
|
||||
// 账户统计
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
@@ -82,10 +83,10 @@ type UserUsageTrendPoint struct {
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
@@ -94,8 +95,8 @@ type ApiKeyUsageTrendPoint struct {
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats struct {
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"`
|
||||
TotalAPIKeys int64 `json:"total_api_keys"`
|
||||
ActiveAPIKeys int64 `json:"active_api_keys"`
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
@@ -128,7 +129,7 @@ type UserDashboardStats struct {
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
type UsageLogFilters struct {
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
APIKeyID int64
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
@@ -157,9 +158,9 @@ type BatchUserUsageStats struct {
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats struct {
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
@@ -43,6 +43,11 @@ type accountRepository struct {
|
||||
sql sqlExecutor // 原生 SQL 执行接口
|
||||
}
|
||||
|
||||
type tempUnschedSnapshot struct {
|
||||
until *time.Time
|
||||
reason string
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
|
||||
@@ -165,6 +170,11 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
|
||||
accountIDs = append(accountIDs, acc.ID)
|
||||
}
|
||||
|
||||
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -191,6 +201,10 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
|
||||
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
|
||||
out.AccountGroups = ags
|
||||
}
|
||||
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
|
||||
out.TempUnschedulableUntil = snap.until
|
||||
out.TempUnschedulableReason = snap.reason
|
||||
}
|
||||
outByID[entAcc.ID] = out
|
||||
}
|
||||
|
||||
@@ -550,6 +564,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
|
||||
Where(
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -575,6 +590,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
|
||||
dbaccount.PlatformEQ(platform),
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -607,6 +623,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
|
||||
dbaccount.PlatformIn(platforms...),
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -648,6 +665,31 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = $1,
|
||||
temp_unschedulable_reason = $2,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
AND deleted_at IS NULL
|
||||
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
|
||||
`, until, reason, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND deleted_at IS NULL
|
||||
`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -808,6 +850,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
|
||||
now := time.Now()
|
||||
preds = append(preds,
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
)
|
||||
@@ -869,6 +912,10 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -894,12 +941,68 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
|
||||
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
|
||||
out.AccountGroups = ags
|
||||
}
|
||||
if snap, ok := tempUnschedMap[acc.ID]; ok {
|
||||
out.TempUnschedulableUntil = snap.until
|
||||
out.TempUnschedulableReason = snap.reason
|
||||
}
|
||||
outAccounts = append(outAccounts, *out)
|
||||
}
|
||||
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
func tempUnschedulablePredicate() dbpredicate.Account {
|
||||
return dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.LTE(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
|
||||
out := make(map[int64]tempUnschedSnapshot)
|
||||
if len(accountIDs) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
|
||||
FROM accounts
|
||||
WHERE id = ANY($1)
|
||||
`, pq.Array(accountIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var until sql.NullTime
|
||||
var reason sql.NullString
|
||||
if err := rows.Scan(&id, &until, &reason); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var untilPtr *time.Time
|
||||
if until.Valid {
|
||||
tmp := until.Time
|
||||
untilPtr = &tmp
|
||||
}
|
||||
if reason.Valid {
|
||||
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
|
||||
} else {
|
||||
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
|
||||
proxyMap := make(map[int64]*service.Proxy)
|
||||
if len(proxyIDs) == 0 {
|
||||
|
||||
@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
name: "filter_by_type",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
|
||||
},
|
||||
accType: service.AccountTypeApiKey,
|
||||
accType: service.AccountTypeAPIKey,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
|
||||
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := NewApiKeyRepository(entClient)
|
||||
apiKeyRepo := NewAPIKeyRepository(entClient)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
|
||||
@@ -24,7 +24,7 @@ type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
|
||||
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,17 +16,17 @@ type apiKeyRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
created, err := r.client.ApiKey.Create().
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||
created, err := r.client.APIKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
|
||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
||||
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
||||
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrApiKeyNotFound
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||
// 则会更新已删除的记录。
|
||||
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
|
||||
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
|
||||
now := time.Now()
|
||||
builder := r.client.ApiKey.Update().
|
||||
builder := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
|
||||
}
|
||||
if affected == 0 {
|
||||
// 更新影响行数为 0,说明记录不存在或已被软删除。
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
|
||||
@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.ApiKey.Update().
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
exists, err := r.client.ApiKey.Query().
|
||||
exists, err := r.client.APIKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Exist(mixins.SkipSoftDelete(ctx))
|
||||
if err != nil {
|
||||
@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
ids, err := r.client.APIKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
n, err := r.client.ApiKey.Update().
|
||||
n, err := r.client.APIKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.ApiKey{
|
||||
out := &service.APIKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
|
||||
@@ -12,30 +12,30 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyRepoSuite struct {
|
||||
type APIKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
func (s *APIKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
|
||||
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyRepoSuite))
|
||||
func TestAPIKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(APIKeyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
func (s *APIKeyRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser("create@test.com")
|
||||
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
@@ -51,16 +51,16 @@ func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
s.Require().Equal("sk-create-test", got.Key)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
func (s *APIKeyRepoSuite) TestGetByKey() {
|
||||
user := s.mustCreateUser("getbykey@test.com")
|
||||
group := s.mustCreateGroup("g-key")
|
||||
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
@@ -78,16 +78,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||
s.Require().Error(err, "expected error for non-existent key")
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
func (s *APIKeyRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser("update@test.com")
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
@@ -108,10 +108,10 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
s.Require().Equal(service.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := s.mustCreateUser("cleargroup@test.com")
|
||||
group := s.mustCreateGroup("g-clear")
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
@@ -131,9 +131,9 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
func (s *APIKeyRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser("delete@test.com")
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
@@ -150,7 +150,7 @@ func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
func (s *APIKeyRepoSuite) TestListByUserID() {
|
||||
user := s.mustCreateUser("listbyuser@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
@@ -161,7 +161,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := s.mustCreateUser("paging@test.com")
|
||||
for i := 0; i < 5; i++ {
|
||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
@@ -174,7 +174,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
s.Require().Equal(3, page.Pages)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
func (s *APIKeyRepoSuite) TestCountByUserID() {
|
||||
user := s.mustCreateUser("count@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
|
||||
@@ -186,7 +186,7 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestListByGroupID() {
|
||||
user := s.mustCreateUser("listbygroup@test.com")
|
||||
group := s.mustCreateGroup("g-list")
|
||||
|
||||
@@ -202,7 +202,7 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
s.Require().NotNil(keys[0].User)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestCountByGroupID() {
|
||||
user := s.mustCreateUser("countgroup@test.com")
|
||||
group := s.mustCreateGroup("g-count")
|
||||
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
@@ -214,7 +214,7 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
func (s *APIKeyRepoSuite) TestExistsByKey() {
|
||||
user := s.mustCreateUser("exists@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
|
||||
|
||||
@@ -227,41 +227,41 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- SearchApiKeys ---
|
||||
// --- SearchAPIKeys ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
|
||||
user := s.mustCreateUser("search@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Contains(found[0].Name, "Production")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
|
||||
user := s.mustCreateUser("searchnokw@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 2)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
|
||||
user := s.mustCreateUser("searchnouid@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 1)
|
||||
}
|
||||
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := s.mustCreateUser("cleargrp@test.com")
|
||||
group := s.mustCreateGroup("g-clear-bulk")
|
||||
|
||||
@@ -284,7 +284,7 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := s.mustCreateUser("k@example.com")
|
||||
group := s.mustCreateGroup("g-k")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
@@ -320,8 +320,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists, "expected key to exist")
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
@@ -346,7 +346,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
@@ -359,7 +359,7 @@ func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
@@ -370,10 +370,10 @@ func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
|
||||
func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.APIKey {
|
||||
s.T().Helper()
|
||||
|
||||
k := &service.ApiKey{
|
||||
k := &service.APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: name,
|
||||
|
||||
@@ -5,28 +5,20 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
client *claudeOAuthService
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// requestCapture holds captured request data for assertions in the main goroutine.
|
||||
type requestCapture struct {
|
||||
path string
|
||||
@@ -37,6 +29,12 @@ type requestCapture struct {
|
||||
contentType string
|
||||
}
|
||||
|
||||
func newTestReqClient(rt http.RoundTripper) *req.Client {
|
||||
c := req.C()
|
||||
c.GetClient().Transport = rt
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -83,17 +81,17 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.cookies = r.Cookies()
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
s.client.baseURL = "http://in-process"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
|
||||
|
||||
@@ -158,20 +156,20 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.method = r.Method
|
||||
captured.cookies = r.Cookies()
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
s.client.baseURL = "http://in-process"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||
|
||||
@@ -266,19 +264,19 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
s.client.tokenURL = "http://in-process/token"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
|
||||
|
||||
@@ -362,19 +360,19 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
s.client.tokenURL = "http://in-process/token"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ type usageRequestCapture struct {
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
var captured usageRequestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.authorization = r.Header.Get("Authorization")
|
||||
captured.anthropicBeta = r.Header.Get("anthropic-beta")
|
||||
|
||||
@@ -59,7 +59,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "nope")
|
||||
}))
|
||||
@@ -73,7 +73,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
@@ -86,7 +86,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Never respond - simulate slow server
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
@@ -309,7 +309,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package infrastructure 提供应用程序的基础设施层组件。
|
||||
// Package repository 提供应用程序的基础设施层组件。
|
||||
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
|
||||
package repository
|
||||
|
||||
|
||||
@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
|
||||
return a
|
||||
}
|
||||
|
||||
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
|
||||
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -257,7 +257,7 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *se
|
||||
k.Name = "default"
|
||||
}
|
||||
|
||||
create := client.ApiKey.Create().
|
||||
create := client.APIKey.Create().
|
||||
SetUserID(k.UserID).
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
|
||||
@@ -30,6 +30,7 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
|
||||
|
||||
// Use different OAuth clients based on oauthType:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
|
||||
@@ -49,7 +49,7 @@ func (s *GitHubReleaseServiceSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", "100")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
||||
@@ -68,7 +68,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
@@ -95,7 +95,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
@@ -123,7 +123,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
@@ -140,7 +140,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("sum"))
|
||||
}))
|
||||
@@ -155,7 +155,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
@@ -168,7 +168,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
@@ -195,7 +195,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("content"))
|
||||
}))
|
||||
@@ -233,7 +233,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
]
|
||||
}`
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
|
||||
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
|
||||
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
|
||||
@@ -258,7 +258,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
@@ -274,7 +274,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("not valid json"))
|
||||
}))
|
||||
@@ -290,7 +290,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
@@ -308,7 +308,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
|
||||
@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
|
||||
// 2. Clear group_id for api keys bound to this group.
|
||||
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
|
||||
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.ApiKey.Update().
|
||||
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.APIKey.Update().
|
||||
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx); err != nil {
|
||||
|
||||
@@ -3,7 +3,6 @@ package repository
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -93,7 +92,7 @@ func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
|
||||
// 验证空代理 URL 时请求直接发送到目标服务器
|
||||
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
|
||||
// 创建模拟上游服务器
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
@@ -115,7 +114,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
|
||||
// 用于接收代理请求的通道
|
||||
seen := make(chan string, 1)
|
||||
// 创建模拟代理服务器
|
||||
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI // 记录请求 URI
|
||||
_, _ = io.WriteString(w, "proxied")
|
||||
}))
|
||||
@@ -145,7 +144,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
|
||||
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
|
||||
// 验证空字符串代理等同于直连
|
||||
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct-empty")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
|
||||
63
backend/internal/repository/inprocess_transport_test.go
Normal file
63
backend/internal/repository/inprocess_transport_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
|
||||
// It captures the request body (if any) and then rewinds it before invoking the handler.
|
||||
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
|
||||
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
var body []byte
|
||||
if r.Body != nil {
|
||||
body, _ = io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
if capture != nil {
|
||||
capture(r, body)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, r)
|
||||
return rec.Result(), nil
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
canListenOnce sync.Once
|
||||
canListen bool
|
||||
canListenErr error
|
||||
)
|
||||
|
||||
func localListenerAvailable() bool {
|
||||
canListenOnce.Do(func() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
canListenErr = err
|
||||
canListen = false
|
||||
return
|
||||
}
|
||||
_ = ln.Close()
|
||||
canListen = true
|
||||
})
|
||||
return canListen
|
||||
}
|
||||
|
||||
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
|
||||
tb.Helper()
|
||||
if !localListenerAvailable() {
|
||||
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
|
||||
}
|
||||
return httptest.NewServer(handler)
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func (s *OpenAIOAuthServiceSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.srv = newLocalTestServer(s.T(), handler)
|
||||
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func (s *PricingServiceSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.srv = newLocalTestServer(s.T(), handler)
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
|
||||
|
||||
@@ -31,7 +31,7 @@ func (s *ProxyProbeServiceSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
|
||||
s.proxySrv = httptest.NewServer(handler)
|
||||
s.proxySrv = newLocalTestServer(s.T(), handler)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
|
||||
|
||||
@@ -41,8 +41,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
|
||||
Name: "soft-delete",
|
||||
@@ -53,13 +53,13 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
_, err := repo.GetByID(ctx, key.ID)
|
||||
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
|
||||
|
||||
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
|
||||
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
|
||||
require.Error(t, err, "default ent query should not see soft-deleted rows")
|
||||
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
|
||||
|
||||
got, err := client.ApiKey.Query().
|
||||
got, err := client.APIKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
|
||||
@@ -73,8 +73,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
|
||||
Name: "soft-delete2",
|
||||
@@ -93,8 +93,8 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
|
||||
Name: "soft-delete3",
|
||||
@@ -105,10 +105,10 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
|
||||
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
|
||||
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "hard delete")
|
||||
|
||||
_, err = client.ApiKey.Query().
|
||||
_, err = client.APIKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
|
||||
|
||||
91
backend/internal/repository/temp_unsched_cache.go
Normal file
91
backend/internal/repository/temp_unsched_cache.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const tempUnschedPrefix = "temp_unsched:account:"
|
||||
|
||||
var tempUnschedSetScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local new_until = tonumber(ARGV[1])
|
||||
local new_value = ARGV[2]
|
||||
local new_ttl = tonumber(ARGV[3])
|
||||
|
||||
local existing = redis.call('GET', key)
|
||||
if existing then
|
||||
local ok, existing_data = pcall(cjson.decode, existing)
|
||||
if ok and existing_data and existing_data.until_unix then
|
||||
local existing_until = tonumber(existing_data.until_unix)
|
||||
if existing_until and new_until <= existing_until then
|
||||
return 0
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
redis.call('SET', key, new_value, 'EX', new_ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
type tempUnschedCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewTempUnschedCache(rdb *redis.Client) service.TempUnschedCache {
|
||||
return &tempUnschedCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// SetTempUnsched 设置临时不可调度状态(只延长不缩短)
|
||||
func (c *tempUnschedCache) SetTempUnsched(ctx context.Context, accountID int64, state *service.TempUnschedState) error {
|
||||
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
|
||||
|
||||
stateJSON, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal state: %w", err)
|
||||
}
|
||||
|
||||
ttl := time.Until(time.Unix(state.UntilUnix, 0))
|
||||
if ttl <= 0 {
|
||||
return nil // 已过期,不设置
|
||||
}
|
||||
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
if ttlSeconds < 1 {
|
||||
ttlSeconds = 1
|
||||
}
|
||||
|
||||
_, err = tempUnschedSetScript.Run(ctx, c.rdb, []string{key}, state.UntilUnix, string(stateJSON), ttlSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTempUnsched 获取临时不可调度状态
|
||||
func (c *tempUnschedCache) GetTempUnsched(ctx context.Context, accountID int64) (*service.TempUnschedState, error) {
|
||||
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
|
||||
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var state service.TempUnschedState
|
||||
if err := json.Unmarshal([]byte(val), &state); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal state: %w", err)
|
||||
}
|
||||
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
// DeleteTempUnsched 删除临时不可调度状态
|
||||
func (c *tempUnschedCache) DeleteTempUnsched(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -3,9 +3,9 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
type TurnstileServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
srv *httptest.Server
|
||||
verifier *turnstileVerifier
|
||||
received chan url.Values
|
||||
}
|
||||
@@ -31,20 +30,15 @@ func (s *TurnstileServiceSuite) SetupTest() {
|
||||
s.verifier = verifier
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) {
|
||||
s.verifier.verifyURL = "http://in-process/turnstile"
|
||||
s.verifier.httpClient = &http.Client{
|
||||
Transport: newInProcessTransport(handler, nil),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.verifier.verifyURL = s.srv.URL
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture form data in main goroutine context later
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
values, _ := url.ParseQuery(string(body))
|
||||
@@ -72,7 +66,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
|
||||
var contentType string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
contentType = r.Header.Get("Content-Type")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
|
||||
@@ -84,7 +78,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
values, _ := url.ParseQuery(string(body))
|
||||
s.received <- values
|
||||
@@ -105,15 +99,19 @@ func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.srv.Close()
|
||||
s.verifier.verifyURL = "http://in-process/turnstile"
|
||||
s.verifier.httpClient = &http.Client{
|
||||
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("dial failed")
|
||||
}),
|
||||
}
|
||||
|
||||
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||
require.Error(s.T(), err, "expected error when server is closed")
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-valid-json")
|
||||
}))
|
||||
@@ -123,7 +121,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
|
||||
Success: false,
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -60,9 +61,16 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
|
||||
return requestCount / 5, tokenCount / 5, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
|
||||
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
if log == nil {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。
|
||||
// 无事务时回退到默认的 *sql.DB 执行器。
|
||||
sqlq := r.sql
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
sqlq = tx.Client()
|
||||
}
|
||||
|
||||
createdAt := log.CreatedAt
|
||||
@@ -70,6 +78,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
createdAt = time.Now()
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(log.RequestID)
|
||||
log.RequestID = requestID
|
||||
|
||||
rateMultiplier := log.RateMultiplier
|
||||
|
||||
query := `
|
||||
@@ -107,6 +118,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
`
|
||||
|
||||
@@ -115,11 +127,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
|
||||
var requestIDArg any
|
||||
if requestID != "" {
|
||||
requestIDArg = requestID
|
||||
}
|
||||
|
||||
args := []any{
|
||||
log.UserID,
|
||||
log.ApiKeyID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
requestIDArg,
|
||||
log.Model,
|
||||
groupID,
|
||||
subscriptionID,
|
||||
@@ -142,11 +159,20 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
firstToken,
|
||||
createdAt,
|
||||
}
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
return err
|
||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
|
||||
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
||||
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||||
return false, err
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
return false, nil
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||
@@ -183,7 +209,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
|
||||
}
|
||||
|
||||
@@ -270,8 +296,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
r.sql,
|
||||
apiKeyStatsQuery,
|
||||
[]any{service.StatusActive},
|
||||
&stats.TotalApiKeys,
|
||||
&stats.ActiveApiKeys,
|
||||
&stats.TotalAPIKeys,
|
||||
&stats.ActiveAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -418,8 +444,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
@@ -623,7 +649,7 @@ func resolveUsageStatsTimezone() string {
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
@@ -709,11 +735,11 @@ type ModelStat = usagestats.ModelStat
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
||||
|
||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
|
||||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -755,10 +781,10 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]ApiKeyUsageTrendPoint, 0)
|
||||
results = make([]APIKeyUsageTrendPoint, 0)
|
||||
for rows.Next() {
|
||||
var row ApiKeyUsageTrendPoint
|
||||
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
var row APIKeyUsageTrendPoint
|
||||
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
@@ -844,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
|
||||
[]any{userID},
|
||||
&stats.TotalApiKeys,
|
||||
&stats.TotalAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -853,7 +879,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
|
||||
[]any{userID, service.StatusActive},
|
||||
&stats.ActiveApiKeys,
|
||||
&stats.ActiveAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1023,9 +1049,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||||
args = append(args, filters.UserID)
|
||||
}
|
||||
if filters.ApiKeyID > 0 {
|
||||
if filters.APIKeyID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||||
args = append(args, filters.ApiKeyID)
|
||||
args = append(args, filters.APIKeyID)
|
||||
}
|
||||
if filters.AccountID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||||
@@ -1145,18 +1171,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||
|
||||
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchApiKeyUsageStats)
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, id := range apiKeyIDs {
|
||||
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
|
||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||
}
|
||||
|
||||
query := `
|
||||
@@ -1582,7 +1608,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
|
||||
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1603,8 +1629,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
|
||||
if user, ok := users[logs[i].UserID]; ok {
|
||||
logs[i].User = user
|
||||
}
|
||||
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
|
||||
logs[i].ApiKey = key
|
||||
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
|
||||
logs[i].APIKey = key
|
||||
}
|
||||
if acc, ok := accounts[logs[i].AccountID]; ok {
|
||||
logs[i].Account = acc
|
||||
@@ -1642,7 +1668,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
|
||||
|
||||
for i := range logs {
|
||||
userIDs[logs[i].UserID] = struct{}{}
|
||||
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
|
||||
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
|
||||
accountIDs[logs[i].AccountID] = struct{}{}
|
||||
if logs[i].GroupID != nil {
|
||||
groupIDs[*logs[i].GroupID] = struct{}{}
|
||||
@@ -1676,12 +1702,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
|
||||
out := make(map[int64]*service.ApiKey)
|
||||
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
|
||||
out := make(map[int64]*service.APIKey)
|
||||
if len(ids) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||||
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1800,7 +1826,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
log := &service.UsageLog{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
Model: model,
|
||||
InputTokens: inputTokens,
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
@@ -35,11 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UsageLogRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.New().String(), // Generate unique RequestID for each log
|
||||
Model: "claude-3",
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
@@ -47,7 +50,8 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
|
||||
ActualCost: cost,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log))
|
||||
_, err := s.repo.Create(s.ctx, log)
|
||||
s.Require().NoError(err)
|
||||
return log
|
||||
}
|
||||
|
||||
@@ -55,12 +59,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
|
||||
|
||||
func (s *UsageLogRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
@@ -69,14 +73,14 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
ActualCost: 0.4,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, log)
|
||||
_, err := s.repo.Create(s.ctx, log)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(log.ID)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -96,7 +100,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -112,7 +116,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -124,18 +128,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
// --- ListByApiKey ---
|
||||
// --- ListByAPIKey ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
func (s *UsageLogRepoSuite) TestListByAPIKey() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByApiKey")
|
||||
logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByAPIKey")
|
||||
s.Require().Len(logs, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
@@ -144,7 +148,7 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -159,7 +163,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -179,7 +183,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -211,8 +215,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
})
|
||||
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
|
||||
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
|
||||
@@ -223,7 +227,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
d1, d2, d3 := 100, 200, 300
|
||||
logToday := &service.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
GroupID: &group.ID,
|
||||
@@ -236,11 +240,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
DurationMs: &d1,
|
||||
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
|
||||
_, err = s.repo.Create(s.ctx, logToday)
|
||||
s.Require().NoError(err, "Create logToday")
|
||||
|
||||
logOld := &service.UsageLog{
|
||||
UserID: userOld.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
@@ -250,11 +255,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
DurationMs: &d2,
|
||||
CreatedAt: todayStart.Add(-1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
|
||||
_, err = s.repo.Create(s.ctx, logOld)
|
||||
s.Require().NoError(err, "Create logOld")
|
||||
|
||||
logPerf := &service.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 1,
|
||||
@@ -264,7 +270,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
DurationMs: &d3,
|
||||
CreatedAt: now.Add(-30 * time.Second),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf")
|
||||
_, err = s.repo.Create(s.ctx, logPerf)
|
||||
s.Require().NoError(err, "Create logPerf")
|
||||
|
||||
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||
s.Require().NoError(err, "GetDashboardStats")
|
||||
@@ -272,8 +279,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
|
||||
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
|
||||
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
|
||||
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
|
||||
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
|
||||
s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch")
|
||||
s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch")
|
||||
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
|
||||
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
|
||||
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
|
||||
@@ -300,14 +307,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetUserDashboardStats")
|
||||
s.Require().Equal(int64(1), stats.TotalApiKeys)
|
||||
s.Require().Equal(int64(1), stats.TotalAPIKeys)
|
||||
s.Require().Equal(int64(1), stats.TotalRequests)
|
||||
}
|
||||
|
||||
@@ -315,7 +322,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -331,8 +338,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
|
||||
|
||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
@@ -351,24 +358,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
|
||||
// --- GetBatchApiKeyUsageStats ---
|
||||
// --- GetBatchAPIKeyUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
|
||||
|
||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
@@ -377,7 +384,7 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -402,7 +409,7 @@ func maxTime(a, b time.Time) time.Time {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -417,11 +424,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
// --- ListByApiKeyAndTimeRange ---
|
||||
// --- ListByAPIKeyAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -431,8 +438,8 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
|
||||
logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "ListByAPIKeyAndTimeRange")
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
@@ -440,7 +447,7 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -459,7 +466,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -467,7 +474,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
// Create logs with different models
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 10,
|
||||
@@ -476,11 +483,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: base,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 15,
|
||||
@@ -489,11 +497,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
ActualCost: 0.6,
|
||||
CreatedAt: base.Add(30 * time.Minute),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log3 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 20,
|
||||
@@ -502,7 +511,8 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
ActualCost: 0.7,
|
||||
CreatedAt: base.Add(1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log3))
|
||||
_, err = s.repo.Create(s.ctx, log3)
|
||||
s.Require().NoError(err)
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
@@ -515,7 +525,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
|
||||
|
||||
now := time.Now()
|
||||
@@ -535,7 +545,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -552,7 +562,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -571,7 +581,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -579,7 +589,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
// Create logs with different models
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
@@ -588,11 +598,12 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: base,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
@@ -601,7 +612,8 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
ActualCost: 0.2,
|
||||
CreatedAt: base.Add(1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
@@ -618,7 +630,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -646,7 +658,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -665,14 +677,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
@@ -681,11 +693,12 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: base,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
@@ -694,7 +707,8 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
ActualCost: 0.2,
|
||||
CreatedAt: base.Add(1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
@@ -719,7 +733,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
@@ -727,7 +741,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
// Create logs on different days
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
@@ -736,11 +750,12 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
ActualCost: 0.4,
|
||||
CreatedAt: base.Add(12 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
@@ -749,7 +764,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
ActualCost: 0.15,
|
||||
CreatedAt: base.Add(36 * time.Hour), // next day
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
startTime := base
|
||||
endTime := base.Add(72 * time.Hour)
|
||||
@@ -782,8 +798,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -799,12 +815,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
// --- GetApiKeyUsageTrend ---
|
||||
// --- GetAPIKeyUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -815,14 +831,14 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||
s.Require().NoError(err, "GetApiKeyUsageTrend")
|
||||
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||
s.Require().NoError(err, "GetAPIKeyUsageTrend")
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -832,8 +848,8 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
|
||||
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
|
||||
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
|
||||
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
|
||||
@@ -841,12 +857,12 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
|
||||
filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID}
|
||||
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||
s.Require().NoError(err, "ListWithFilters apiKey")
|
||||
s.Require().Len(logs, 1)
|
||||
@@ -855,7 +871,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -874,7 +890,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -885,7 +901,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
StartTime: &startTime,
|
||||
EndTime: &endTime,
|
||||
}
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -17,14 +18,15 @@ import (
|
||||
|
||||
type userRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||||
return newUserRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
|
||||
return &userRepository{client: client}
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
|
||||
return &userRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||||
@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
// If attribute filters are specified, we need to filter by user IDs first
|
||||
var allowedUserIDs []int64
|
||||
if len(filters.Attributes) > 0 {
|
||||
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||||
var attrErr error
|
||||
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||||
if attrErr != nil {
|
||||
return nil, nil, attrErr
|
||||
}
|
||||
if len(allowedUserIDs) == 0 {
|
||||
// No users match the attribute filters
|
||||
return []service.User{}, paginationResultFromTotal(0, params), nil
|
||||
@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
}
|
||||
|
||||
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
|
||||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
|
||||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
|
||||
if len(attrs) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// For each attribute filter, get the set of matching user IDs
|
||||
// Then intersect all sets to get users matching ALL filters
|
||||
var resultSet map[int64]struct{}
|
||||
first := true
|
||||
if r.sql == nil {
|
||||
return nil, fmt.Errorf("sql executor is not configured")
|
||||
}
|
||||
|
||||
clauses := make([]string, 0, len(attrs))
|
||||
args := make([]any, 0, len(attrs)*2+1)
|
||||
argIndex := 1
|
||||
for attrID, value := range attrs {
|
||||
// Query user_attribute_values for this attribute
|
||||
values, err := r.client.UserAttributeValue.Query().
|
||||
Where(
|
||||
userattributevalue.AttributeIDEQ(attrID),
|
||||
userattributevalue.ValueContainsFold(value),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
currentSet := make(map[int64]struct{}, len(values))
|
||||
for _, v := range values {
|
||||
currentSet[v.UserID] = struct{}{}
|
||||
}
|
||||
|
||||
if first {
|
||||
resultSet = currentSet
|
||||
first = false
|
||||
} else {
|
||||
// Intersect with previous results
|
||||
for userID := range resultSet {
|
||||
if _, ok := currentSet[userID]; !ok {
|
||||
delete(resultSet, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Early exit if no users match
|
||||
if len(resultSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
|
||||
args = append(args, attrID, "%"+value+"%")
|
||||
argIndex += 2
|
||||
}
|
||||
|
||||
result := make([]int64, 0, len(resultSet))
|
||||
for userID := range resultSet {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT user_id
|
||||
FROM user_attribute_values
|
||||
WHERE %s
|
||||
GROUP BY user_id
|
||||
HAVING COUNT(DISTINCT attribute_id) = $%d`,
|
||||
strings.Join(clauses, " OR "),
|
||||
argIndex,
|
||||
)
|
||||
args = append(args, len(attrs))
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
result := make([]int64, 0)
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
if scanErr := rows.Scan(&userID); scanErr != nil {
|
||||
return nil, scanErr
|
||||
}
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
|
||||
@@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
NewApiKeyRepository,
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
@@ -42,7 +42,8 @@ var ProviderSet = wire.NewSet(
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
NewBillingCache,
|
||||
NewApiKeyCache,
|
||||
NewAPIKeyCache,
|
||||
NewTempUnschedCache,
|
||||
ProvideConcurrencyCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
|
||||
@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
name: "GET /api/v1/keys (paginated)",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
deps.apiKeyRepo.MustSeed(&service.ApiKey{
|
||||
deps.apiKeyRepo.MustSeed(&service.APIKey{
|
||||
ID: 100,
|
||||
UserID: 1,
|
||||
Key: "sk_custom_1234567890",
|
||||
@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
ApiKeyID: 100,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
{
|
||||
ID: 2,
|
||||
UserID: 1,
|
||||
ApiKeyID: 100,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
ApiKeyID: 100,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
RequestID: "req_123",
|
||||
Model: "claude-3",
|
||||
@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
|
||||
service.SettingKeySmtpHost: "smtp.example.com",
|
||||
service.SettingKeySmtpPort: "587",
|
||||
service.SettingKeySmtpUsername: "user",
|
||||
service.SettingKeySmtpPassword: "secret",
|
||||
service.SettingKeySmtpFrom: "no-reply@example.com",
|
||||
service.SettingKeySmtpFromName: "Sub2API",
|
||||
service.SettingKeySmtpUseTLS: "true",
|
||||
service.SettingKeySMTPHost: "smtp.example.com",
|
||||
service.SettingKeySMTPPort: "587",
|
||||
service.SettingKeySMTPUsername: "user",
|
||||
service.SettingKeySMTPPassword: "secret",
|
||||
service.SettingKeySMTPFrom: "no-reply@example.com",
|
||||
service.SettingKeySMTPFromName: "Sub2API",
|
||||
service.SettingKeySMTPUseTLS: "true",
|
||||
|
||||
service.SettingKeyTurnstileEnabled: "true",
|
||||
service.SettingKeyTurnstileSiteKey: "site-key",
|
||||
@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
service.SettingKeySiteName: "Sub2API",
|
||||
service.SettingKeySiteLogo: "",
|
||||
service.SettingKeySiteSubtitle: "Subtitle",
|
||||
service.SettingKeyApiBaseUrl: "https://api.example.com",
|
||||
service.SettingKeyAPIBaseURL: "https://api.example.com",
|
||||
service.SettingKeyContactInfo: "support",
|
||||
service.SettingKeyDocUrl: "https://docs.example.com",
|
||||
service.SettingKeyDocURL: "https://docs.example.com",
|
||||
|
||||
service.SettingKeyDefaultConcurrency: "5",
|
||||
service.SettingKeyDefaultBalance: "1.25",
|
||||
@@ -308,7 +308,12 @@ func TestAPIContracts(t *testing.T) {
|
||||
"contact_info": "support",
|
||||
"doc_url": "https://docs.example.com",
|
||||
"default_concurrency": 5,
|
||||
"default_balance": 1.25
|
||||
"default_balance": 1.25,
|
||||
"enable_model_fallback": false,
|
||||
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
||||
"fallback_model_antigravity": "gemini-2.5-pro",
|
||||
"fallback_model_gemini": "gemini-2.5-pro",
|
||||
"fallback_model_openai": "gpt-4o"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
@@ -366,16 +371,16 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
ApiKeyPrefix: "sk-",
|
||||
APIKeyPrefix: "sk-",
|
||||
},
|
||||
RunMode: config.RunModeStandard,
|
||||
}
|
||||
|
||||
userService := service.NewUserService(userRepo)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||
|
||||
usageRepo := newStubUsageLogRepo()
|
||||
usageService := service.NewUsageService(usageRepo, userRepo)
|
||||
usageService := service.NewUsageService(usageRepo, userRepo, nil)
|
||||
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
@@ -664,20 +669,20 @@ type stubApiKeyRepo struct {
|
||||
now time.Time
|
||||
|
||||
nextID int64
|
||||
byID map[int64]*service.ApiKey
|
||||
byKey map[string]*service.ApiKey
|
||||
byID map[int64]*service.APIKey
|
||||
byKey map[string]*service.APIKey
|
||||
}
|
||||
|
||||
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
|
||||
return &stubApiKeyRepo{
|
||||
now: now,
|
||||
nextID: 100,
|
||||
byID: make(map[int64]*service.ApiKey),
|
||||
byKey: make(map[string]*service.ApiKey),
|
||||
byID: make(map[int64]*service.APIKey),
|
||||
byKey: make(map[string]*service.APIKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
|
||||
func (r *stubApiKeyRepo) MustSeed(key *service.APIKey) {
|
||||
if key == nil {
|
||||
return
|
||||
}
|
||||
@@ -686,7 +691,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
|
||||
r.byKey[clone.Key] = &clone
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
if key == nil {
|
||||
return errors.New("nil key")
|
||||
}
|
||||
@@ -706,10 +711,10 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
key, ok := r.byID[id]
|
||||
if !ok {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *key
|
||||
return &clone, nil
|
||||
@@ -718,26 +723,26 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey
|
||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
key, ok := r.byID[id]
|
||||
if !ok {
|
||||
return 0, service.ErrApiKeyNotFound
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return key.UserID, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
found, ok := r.byKey[key]
|
||||
if !ok {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *found
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
if key == nil {
|
||||
return errors.New("nil key")
|
||||
}
|
||||
if _, ok := r.byID[key.ID]; !ok {
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
if key.UpdatedAt.IsZero() {
|
||||
key.UpdatedAt = r.now
|
||||
@@ -751,14 +756,14 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error
|
||||
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
key, ok := r.byID[id]
|
||||
if !ok {
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
delete(r.byID, id)
|
||||
delete(r.byKey, key.Key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
ids := make([]int64, 0, len(r.byID))
|
||||
for id := range r.byID {
|
||||
if r.byID[id].UserID == userID {
|
||||
@@ -776,7 +781,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
|
||||
end = len(ids)
|
||||
}
|
||||
|
||||
out := make([]service.ApiKey, 0, end-start)
|
||||
out := make([]service.APIKey, 0, end-start)
|
||||
for _, id := range ids[start:end] {
|
||||
clone := *r.byID[id]
|
||||
out = append(out, clone)
|
||||
@@ -830,11 +835,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -858,8 +863,8 @@ func (r *stubUsageLogRepo) SetUserLogs(userID int64, logs []service.UsageLog) {
|
||||
r.userLogs[userID] = logs
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) error {
|
||||
return errors.New("not implemented")
|
||||
func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||
@@ -877,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
|
||||
return out, paginationResult(total, params), nil
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -890,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
|
||||
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -922,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
|
||||
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -975,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -995,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
|
||||
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -1017,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
|
||||
// Apply filters
|
||||
var filtered []service.UsageLog
|
||||
for _, log := range logs {
|
||||
// Apply ApiKeyID filter
|
||||
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID {
|
||||
// Apply APIKeyID filter
|
||||
if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID {
|
||||
continue
|
||||
}
|
||||
// Apply Model filter
|
||||
@@ -1151,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
|
||||
// Ensure compile-time interface compliance.
|
||||
var (
|
||||
_ service.UserRepository = (*stubUserRepo)(nil)
|
||||
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil)
|
||||
_ service.ApiKeyCache = (*stubApiKeyCache)(nil)
|
||||
_ service.APIKeyRepository = (*stubApiKeyRepo)(nil)
|
||||
_ service.APIKeyCache = (*stubApiKeyCache)(nil)
|
||||
_ service.GroupRepository = (*stubGroupRepo)(nil)
|
||||
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
|
||||
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package server provides HTTP server initialization and configuration.
|
||||
package server
|
||||
|
||||
import (
|
||||
@@ -25,8 +26,8 @@ func ProvideRouter(
|
||||
handlers *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package middleware provides HTTP middleware for authentication, authorization, and request processing.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
@@ -32,7 +33,7 @@ func adminAuth(
|
||||
// 检查 x-api-key header(Admin API Key 认证)
|
||||
apiKey := c.GetHeader("x-api-key")
|
||||
if apiKey != "" {
|
||||
if !validateAdminApiKey(c, apiKey, settingService, userService) {
|
||||
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -57,14 +58,14 @@ func adminAuth(
|
||||
}
|
||||
}
|
||||
|
||||
// validateAdminApiKey 验证管理员 API Key
|
||||
func validateAdminApiKey(
|
||||
// validateAdminAPIKey 验证管理员 API Key
|
||||
func validateAdminAPIKey(
|
||||
c *gin.Context,
|
||||
key string,
|
||||
settingService *service.SettingService,
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
|
||||
storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
|
||||
return false
|
||||
|
||||
@@ -11,13 +11,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
|
||||
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
|
||||
// NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) APIKeyAuthMiddleware {
|
||||
return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
@@ -60,7 +60,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
// 从数据库验证API key
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
|
||||
return
|
||||
}
|
||||
@@ -88,7 +88,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
@@ -146,7 +146,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
}
|
||||
|
||||
// 将API key和用户信息存入上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
@@ -157,13 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
}
|
||||
}
|
||||
|
||||
// GetApiKeyFromContext 从上下文中获取API key
|
||||
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyApiKey))
|
||||
// GetAPIKeyFromContext 从上下文中获取API key
|
||||
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyAPIKey))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
apiKey, ok := value.(*service.ApiKey)
|
||||
apiKey, ok := value.(*service.APIKey)
|
||||
return apiKey, ok
|
||||
}
|
||||
|
||||
|
||||
@@ -11,16 +11,16 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
// APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
}
|
||||
|
||||
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
|
||||
// APIKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
|
||||
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
|
||||
//
|
||||
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
|
||||
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKeyString := extractAPIKeyFromRequest(c)
|
||||
if apiKeyString == "" {
|
||||
@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
abortWithGoogleError(c, 401, "Invalid API key")
|
||||
return
|
||||
}
|
||||
@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
|
||||
|
||||
// 简易模式:跳过余额和订阅检查
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
|
||||
}
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
|
||||
@@ -16,53 +16,53 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
type fakeAPIKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
}
|
||||
|
||||
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if f.getByKey == nil {
|
||||
return nil, errors.New("unexpected call")
|
||||
}
|
||||
return f.getByKey(ctx, key)
|
||||
}
|
||||
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -74,8 +74,8 @@ type googleErrorResponse struct {
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
|
||||
return service.NewApiKeyService(
|
||||
func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
|
||||
return service.NewAPIKeyService(
|
||||
repo,
|
||||
nil, // userRepo (unused in GetByKey)
|
||||
nil, // groupRepo
|
||||
@@ -89,12 +89,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -113,12 +113,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -138,12 +138,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, errors.New("db down")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -163,9 +163,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return &service.APIKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusDisabled,
|
||||
@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -196,9 +196,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return &service.APIKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusActive,
|
||||
@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.ApiKey{
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
@@ -46,9 +46,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
@@ -110,9 +110,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
@@ -120,14 +120,14 @@ func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService
|
||||
}
|
||||
|
||||
type stubApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -135,14 +135,14 @@ func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if r.getByKey != nil {
|
||||
return r.getByKey(ctx, key)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -166,11 +166,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -15,8 +15,8 @@ const (
|
||||
ContextKeyUser ContextKey = "user"
|
||||
// ContextKeyUserRole 当前用户角色(string)
|
||||
ContextKeyUserRole ContextKey = "user_role"
|
||||
// ContextKeyApiKey API密钥上下文键
|
||||
ContextKeyApiKey ContextKey = "api_key"
|
||||
// ContextKeyAPIKey API密钥上下文键
|
||||
ContextKeyAPIKey ContextKey = "api_key"
|
||||
// ContextKeySubscription 订阅上下文键
|
||||
ContextKeySubscription ContextKey = "subscription"
|
||||
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
|
||||
|
||||
@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
|
||||
// AdminAuthMiddleware 管理员认证中间件类型
|
||||
type AdminAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ApiKeyAuthMiddleware API Key 认证中间件类型
|
||||
type ApiKeyAuthMiddleware gin.HandlerFunc
|
||||
// APIKeyAuthMiddleware API Key 认证中间件类型
|
||||
type APIKeyAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ProviderSet 中间件层的依赖注入
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewJWTAuthMiddleware,
|
||||
NewAdminAuthMiddleware,
|
||||
NewApiKeyAuthMiddleware,
|
||||
NewAPIKeyAuthMiddleware,
|
||||
)
|
||||
|
||||
@@ -17,8 +17,8 @@ func SetupRouter(
|
||||
handlers *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) *gin.Engine {
|
||||
@@ -43,8 +43,8 @@ func registerRoutes(
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package routes provides HTTP route registration and handlers.
|
||||
package routes
|
||||
|
||||
import (
|
||||
@@ -67,10 +68,10 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,6 +124,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
@@ -203,12 +206,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
{
|
||||
adminSettings.GET("", h.Admin.Setting.GetSettings)
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
// Admin API Key 管理
|
||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
|
||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
|
||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
|
||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,7 +251,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
usage.GET("", h.Admin.Usage.List)
|
||||
usage.GET("/stats", h.Admin.Usage.Stats)
|
||||
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
func RegisterGatewayRoutes(
|
||||
r *gin.Engine,
|
||||
h *handler.Handlers,
|
||||
apiKeyAuth middleware.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyAuth middleware.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
@@ -36,7 +36,7 @@ func RegisterGatewayRoutes(
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
gemini := r.Group("/v1beta")
|
||||
gemini.Use(bodyLimit)
|
||||
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
{
|
||||
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
@@ -65,7 +65,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1Beta := r.Group("/antigravity/v1beta")
|
||||
antigravityV1Beta.Use(bodyLimit)
|
||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
{
|
||||
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
|
||||
@@ -50,7 +50,7 @@ func RegisterUserRoutes(
|
||||
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
||||
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
||||
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package service provides business logic and domain services for the application.
|
||||
package service
|
||||
|
||||
import (
|
||||
@@ -29,6 +30,9 @@ type Account struct {
|
||||
RateLimitResetAt *time.Time
|
||||
OverloadUntil *time.Time
|
||||
|
||||
TempUnschedulableUntil *time.Time
|
||||
TempUnschedulableReason string
|
||||
|
||||
SessionWindowStart *time.Time
|
||||
SessionWindowEnd *time.Time
|
||||
SessionWindowStatus string
|
||||
@@ -39,6 +43,13 @@ type Account struct {
|
||||
Groups []*Group
|
||||
}
|
||||
|
||||
type TempUnschedulableRule struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
Keywords []string `json:"keywords"`
|
||||
DurationMinutes int `json:"duration_minutes"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
func (a *Account) IsActive() bool {
|
||||
return a.Status == StatusActive
|
||||
}
|
||||
@@ -54,6 +65,9 @@ func (a *Account) IsSchedulable() bool {
|
||||
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
|
||||
return false
|
||||
}
|
||||
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -92,10 +106,7 @@ func (a *Account) GeminiOAuthType() string {
|
||||
|
||||
func (a *Account) GeminiTierID() string {
|
||||
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
|
||||
if tierID == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToUpper(tierID)
|
||||
return tierID
|
||||
}
|
||||
|
||||
func (a *Account) IsGeminiCodeAssist() bool {
|
||||
@@ -163,6 +174,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) IsTempUnschedulableEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := a.Credentials["temp_unschedulable_enabled"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := raw.(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["temp_unschedulable_rules"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
arr, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rules := make([]TempUnschedulableRule, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
entry, ok := item.(map[string]any)
|
||||
if !ok || entry == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
rule := TempUnschedulableRule{
|
||||
ErrorCode: parseTempUnschedInt(entry["error_code"]),
|
||||
Keywords: parseTempUnschedStrings(entry["keywords"]),
|
||||
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
|
||||
Description: parseTempUnschedString(entry["description"]),
|
||||
}
|
||||
|
||||
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
func parseTempUnschedString(value any) string {
|
||||
s, ok := value.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func parseTempUnschedStrings(value any) []string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var raw []string
|
||||
switch v := value.(type) {
|
||||
case []string:
|
||||
raw = v
|
||||
case []any:
|
||||
raw = make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
raw = append(raw, s)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
s := strings.TrimSpace(item)
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseTempUnschedInt(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
@@ -206,7 +325,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
if a.Type != AccountTypeApiKey {
|
||||
if a.Type != AccountTypeAPIKey {
|
||||
return ""
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
@@ -229,7 +348,7 @@ func (a *Account) GetExtraString(key string) string {
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeApiKey || a.Credentials == nil {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
|
||||
@@ -301,14 +420,14 @@ func (a *Account) IsOpenAIOAuth() bool {
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeApiKey {
|
||||
if a.Type == AccountTypeAPIKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
|
||||
@@ -49,6 +49,8 @@ type AccountRepository interface {
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
|
||||
@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim
|
||||
panic("unexpected SetOverloaded call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
panic("unexpected SetTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearRateLimit call")
|
||||
}
|
||||
|
||||
@@ -369,7 +369,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
@@ -393,7 +393,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
var err error
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
|
||||
@@ -12,16 +12,18 @@ import (
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
Create(ctx context.Context, log *UsageLog) error
|
||||
// Create creates a usage log and returns whether it was actually inserted.
|
||||
// inserted is false when the insert was skipped due to conflict (idempotent retries).
|
||||
Create(ctx context.Context, log *UsageLog) (inserted bool, err error)
|
||||
GetByID(ctx context.Context, id int64) (*UsageLog, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
@@ -32,10 +34,10 @@ type UsageLogRepository interface {
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
|
||||
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
|
||||
// User dashboard stats
|
||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||
@@ -51,7 +53,7 @@ type UsageLogRepository interface {
|
||||
|
||||
// Aggregated stats (optimized)
|
||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||
@@ -105,6 +107,8 @@ type UsageProgress struct {
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
UsedRequests int64 `json:"used_requests,omitempty"`
|
||||
LimitRequests int64 `json:"limit_requests,omitempty"`
|
||||
}
|
||||
|
||||
// AntigravityModelQuota Antigravity 单个模型的配额信息
|
||||
@@ -115,12 +119,16 @@ type AntigravityModelQuota struct {
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
||||
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||
GeminiSharedDaily *UsageProgress `json:"gemini_shared_daily,omitempty"` // Gemini shared pool RPD (Google One / Code Assist)
|
||||
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
||||
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
||||
GeminiSharedMinute *UsageProgress `json:"gemini_shared_minute,omitempty"` // Gemini shared pool RPM (Google One / Code Assist)
|
||||
GeminiProMinute *UsageProgress `json:"gemini_pro_minute,omitempty"` // Gemini Pro RPM
|
||||
GeminiFlashMinute *UsageProgress `json:"gemini_flash_minute,omitempty"` // Gemini Flash RPM
|
||||
|
||||
// Antigravity 多模型配额
|
||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||
@@ -256,17 +264,44 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
start := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
|
||||
totals := geminiAggregateUsage(stats)
|
||||
resetAt := geminiDailyResetTime(now)
|
||||
dayTotals := geminiAggregateUsage(stats)
|
||||
dailyResetAt := geminiDailyResetTime(now)
|
||||
|
||||
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
|
||||
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
|
||||
// Daily window (RPD)
|
||||
if quota.SharedRPD > 0 {
|
||||
totalReq := dayTotals.ProRequests + dayTotals.FlashRequests
|
||||
totalTokens := dayTotals.ProTokens + dayTotals.FlashTokens
|
||||
totalCost := dayTotals.ProCost + dayTotals.FlashCost
|
||||
usage.GeminiSharedDaily = buildGeminiUsageProgress(totalReq, quota.SharedRPD, dailyResetAt, totalTokens, totalCost, now)
|
||||
} else {
|
||||
usage.GeminiProDaily = buildGeminiUsageProgress(dayTotals.ProRequests, quota.ProRPD, dailyResetAt, dayTotals.ProTokens, dayTotals.ProCost, now)
|
||||
usage.GeminiFlashDaily = buildGeminiUsageProgress(dayTotals.FlashRequests, quota.FlashRPD, dailyResetAt, dayTotals.FlashTokens, dayTotals.FlashCost, now)
|
||||
}
|
||||
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
minuteTotals := geminiAggregateUsage(minuteStats)
|
||||
|
||||
if quota.SharedRPM > 0 {
|
||||
totalReq := minuteTotals.ProRequests + minuteTotals.FlashRequests
|
||||
totalTokens := minuteTotals.ProTokens + minuteTotals.FlashTokens
|
||||
totalCost := minuteTotals.ProCost + minuteTotals.FlashCost
|
||||
usage.GeminiSharedMinute = buildGeminiUsageProgress(totalReq, quota.SharedRPM, minuteResetAt, totalTokens, totalCost, now)
|
||||
} else {
|
||||
usage.GeminiProMinute = buildGeminiUsageProgress(minuteTotals.ProRequests, quota.ProRPM, minuteResetAt, minuteTotals.ProTokens, minuteTotals.ProCost, now)
|
||||
usage.GeminiFlashMinute = buildGeminiUsageProgress(minuteTotals.FlashRequests, quota.FlashRPM, minuteResetAt, minuteTotals.FlashTokens, minuteTotals.FlashCost, now)
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
@@ -506,6 +541,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
||||
}
|
||||
|
||||
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
|
||||
// limit <= 0 means "no local quota window" (unknown or unlimited).
|
||||
if limit <= 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -519,6 +555,8 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
|
||||
Utilization: utilization,
|
||||
ResetsAt: &resetCopy,
|
||||
RemainingSeconds: remainingSeconds,
|
||||
UsedRequests: used,
|
||||
LimitRequests: limit,
|
||||
WindowStats: &WindowStats{
|
||||
Requests: used,
|
||||
Tokens: tokens,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -19,7 +20,7 @@ type AdminService interface {
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
// Group management
|
||||
@@ -30,7 +31,7 @@ type AdminService interface {
|
||||
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||
DeleteGroup(ctx context.Context, id int64) error
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
||||
@@ -65,7 +66,7 @@ type AdminService interface {
|
||||
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
}
|
||||
|
||||
// Input types for admin operations
|
||||
// CreateUserInput represents input for creating a new user via admin operations.
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
@@ -122,18 +123,22 @@ type CreateAccountInput struct {
|
||||
Concurrency int
|
||||
Priority int
|
||||
GroupIDs []int64
|
||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||
// This should only be set when the caller has explicitly confirmed the risk.
|
||||
SkipMixedChannelCheck bool
|
||||
}
|
||||
|
||||
type UpdateAccountInput struct {
|
||||
Name string
|
||||
Type string // Account type: oauth, setup-token, apikey
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
Name string
|
||||
Type string // Account type: oauth, setup-token, apikey
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||
@@ -147,6 +152,9 @@ type BulkUpdateAccountsInput struct {
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||
// This should only be set when the caller has explicitly confirmed the risk.
|
||||
SkipMixedChannelCheck bool
|
||||
}
|
||||
|
||||
// BulkUpdateAccountResult captures the result for a single account update.
|
||||
@@ -220,7 +228,7 @@ type adminServiceImpl struct {
|
||||
groupRepo GroupRepository
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo ApiKeyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
@@ -232,7 +240,7 @@ func NewAdminService(
|
||||
groupRepo GroupRepository,
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
apiKeyRepo ApiKeyRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
@@ -430,7 +438,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
@@ -583,7 +591,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
@@ -620,6 +628,29 @@ func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
|
||||
// 绑定分组
|
||||
groupIDs := input.GroupIDs
|
||||
// 如果没有指定分组,自动绑定对应平台的默认分组
|
||||
if len(groupIDs) == 0 {
|
||||
defaultGroupName := input.Platform + "-default"
|
||||
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
|
||||
if err == nil {
|
||||
for _, g := range groups {
|
||||
if g.Name == defaultGroupName {
|
||||
groupIDs = []int64{g.ID}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
if len(groupIDs) > 0 && !input.SkipMixedChannelCheck {
|
||||
if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
Name: input.Name,
|
||||
Platform: input.Platform,
|
||||
@@ -637,22 +668,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
groupIDs := input.GroupIDs
|
||||
// 如果没有指定分组,自动绑定对应平台的默认分组
|
||||
if len(groupIDs) == 0 {
|
||||
defaultGroupName := input.Platform + "-default"
|
||||
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
|
||||
if err == nil {
|
||||
for _, g := range groups {
|
||||
if g.Name == defaultGroupName {
|
||||
groupIDs = []int64{g.ID}
|
||||
log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
|
||||
return nil, err
|
||||
@@ -703,6 +718,13 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
if !input.SkipMixedChannelCheck {
|
||||
if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
@@ -731,6 +753,20 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Preload account platforms for mixed channel risk checks if group bindings are requested.
|
||||
platformByID := map[int64]string{}
|
||||
if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, account := range accounts {
|
||||
if account != nil {
|
||||
platformByID[account.ID] = account.Platform
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare bulk updates for columns and JSONB fields.
|
||||
repoUpdates := AccountBulkUpdate{
|
||||
Credentials: input.Credentials,
|
||||
@@ -762,6 +798,29 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
entry := BulkUpdateAccountResult{AccountID: accountID}
|
||||
|
||||
if input.GroupIDs != nil {
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
if !input.SkipMixedChannelCheck {
|
||||
platform := platformByID[accountID]
|
||||
if platform == "" {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
platform = account.Platform
|
||||
}
|
||||
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
@@ -1006,3 +1065,77 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
||||
Country: exitInfo.Country,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
|
||||
// 如果存在混合,返回错误提示用户确认
|
||||
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
// 判断当前账号的渠道类型(基于 platform 字段,而不是 type 字段)
|
||||
currentPlatform := getAccountPlatform(currentAccountPlatform)
|
||||
if currentPlatform == "" {
|
||||
// 不是 Antigravity 或 Anthropic,无需检查
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查每个分组中的其他账号
|
||||
for _, groupID := range groupIDs {
|
||||
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get accounts in group %d: %w", groupID, err)
|
||||
}
|
||||
|
||||
// 检查是否存在不同渠道的账号
|
||||
for _, account := range accounts {
|
||||
if currentAccountID > 0 && account.ID == currentAccountID {
|
||||
continue // 跳过当前账号
|
||||
}
|
||||
|
||||
otherPlatform := getAccountPlatform(account.Platform)
|
||||
if otherPlatform == "" {
|
||||
continue // 不是 Antigravity 或 Anthropic,跳过
|
||||
}
|
||||
|
||||
// 检测混合渠道
|
||||
if currentPlatform != otherPlatform {
|
||||
group, _ := s.groupRepo.GetByID(ctx, groupID)
|
||||
groupName := fmt.Sprintf("Group %d", groupID)
|
||||
if group != nil {
|
||||
groupName = group.Name
|
||||
}
|
||||
|
||||
return &MixedChannelError{
|
||||
GroupID: groupID,
|
||||
GroupName: groupName,
|
||||
CurrentPlatform: currentPlatform,
|
||||
OtherPlatform: otherPlatform,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
|
||||
func getAccountPlatform(accountPlatform string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
|
||||
case PlatformAntigravity:
|
||||
return "Antigravity"
|
||||
case PlatformAnthropic, "claude":
|
||||
return "Anthropic"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// MixedChannelError 混合渠道错误
|
||||
type MixedChannelError struct {
|
||||
GroupID int64
|
||||
GroupName string
|
||||
CurrentPlatform string
|
||||
OtherPlatform string
|
||||
}
|
||||
|
||||
func (e *MixedChannelError) Error() string {
|
||||
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
|
||||
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
|
||||
}
|
||||
|
||||
@@ -81,6 +81,7 @@ type AntigravityGatewayService struct {
|
||||
tokenProvider *AntigravityTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
settingService *SettingService
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
@@ -89,12 +90,14 @@ func NewAntigravityGatewayService(
|
||||
tokenProvider *AntigravityTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
settingService *SettingService,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
settingService: settingService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,6 +327,22 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// isModelNotFoundError 检测是否为模型不存在的 404 错误
|
||||
func isModelNotFoundError(statusCode int, body []byte) bool {
|
||||
if statusCode != 404 {
|
||||
return false
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
keywords := []string{"model not found", "unknown model", "not found"}
|
||||
for _, keyword := range keywords {
|
||||
if strings.Contains(bodyStr, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return true // 404 without specific message also treated as model not found
|
||||
}
|
||||
|
||||
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
||||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -417,16 +436,56 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
// 优先检测 thinking block 的 signature 相关错误(400)并重试一次:
|
||||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||||
retryClaudeReq := claudeReq
|
||||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||||
|
||||
stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq)
|
||||
if stripErr == nil && stripped {
|
||||
log.Printf("Antigravity account %d: detected signature-related 400, retrying once without thinking blocks", account.ID)
|
||||
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGemini(&retryClaudeReq, projectID, mappedModel)
|
||||
if txErr == nil {
|
||||
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
// Retry success: continue normal success flow with the new response.
|
||||
if retryResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
resp = retryResp
|
||||
respBody = nil
|
||||
} else {
|
||||
// Retry still errored: replace error context with retry response.
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
respBody = retryBody
|
||||
resp = retryResp
|
||||
}
|
||||
} else {
|
||||
log.Printf("Antigravity account %d: signature retry request failed: %v", account.ID, retryErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
}
|
||||
}
|
||||
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
@@ -461,6 +520,122 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isSignatureRelatedError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
// Fallback: best-effort scan of the raw payload.
|
||||
msg = strings.ToLower(string(respBody))
|
||||
}
|
||||
|
||||
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
|
||||
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
|
||||
}
|
||||
|
||||
func extractAntigravityErrorMessage(body []byte) string {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Google-style: {"error": {"message": "..."}}
|
||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
return msg
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: top-level message
|
||||
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
return msg
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
||||
// This preserves the thinking content while avoiding signature validation errors.
|
||||
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
||||
// It also disables top-level `thinking` to prevent dummy-thought injection during retry.
|
||||
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
|
||||
if req == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
changed := false
|
||||
if req.Thinking != nil {
|
||||
req.Thinking = nil
|
||||
changed = true
|
||||
}
|
||||
|
||||
for i := range req.Messages {
|
||||
raw := req.Messages[i].Content
|
||||
if len(raw) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If content is a string, nothing to strip.
|
||||
var str string
|
||||
if json.Unmarshal(raw, &str) == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise treat as an array of blocks and convert thinking blocks to text.
|
||||
var blocks []map[string]any
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
filtered := make([]map[string]any, 0, len(blocks))
|
||||
modifiedAny := false
|
||||
for _, block := range blocks {
|
||||
t, _ := block["type"].(string)
|
||||
switch t {
|
||||
case "thinking":
|
||||
// Convert thinking to text, skip if empty
|
||||
thinkingText, _ := block["thinking"].(string)
|
||||
if thinkingText != "" {
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": thinkingText,
|
||||
})
|
||||
}
|
||||
modifiedAny = true
|
||||
case "redacted_thinking":
|
||||
// Remove redacted_thinking (cannot convert encrypted content)
|
||||
modifiedAny = true
|
||||
case "":
|
||||
// Handle untyped block with "thinking" field
|
||||
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
|
||||
if thinkingText != "" {
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": thinkingText,
|
||||
})
|
||||
}
|
||||
modifiedAny = true
|
||||
} else {
|
||||
filtered = append(filtered, block)
|
||||
}
|
||||
default:
|
||||
filtered = append(filtered, block)
|
||||
}
|
||||
}
|
||||
|
||||
if !modifiedAny {
|
||||
continue
|
||||
}
|
||||
|
||||
newRaw, err := json.Marshal(filtered)
|
||||
if err != nil {
|
||||
return changed, err
|
||||
}
|
||||
req.Messages[i].Content = newRaw
|
||||
changed = true
|
||||
}
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// ForwardGemini 转发 Gemini 协议请求
|
||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -574,14 +749,40 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
|
||||
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
|
||||
isModelNotFoundError(resp.StatusCode, respBody) {
|
||||
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
|
||||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||||
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||||
|
||||
// 关闭原始响应,释放连接(respBody 已读取到内存)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
|
||||
if err == nil {
|
||||
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
|
||||
if err == nil {
|
||||
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err == nil && fallbackResp.StatusCode < 400 {
|
||||
resp = fallbackResp
|
||||
} else if fallbackResp != nil {
|
||||
_ = fallbackResp.Body.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fallback 成功:继续按正常响应处理
|
||||
if resp.StatusCode < 400 {
|
||||
goto handleSuccess
|
||||
}
|
||||
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
@@ -589,6 +790,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
|
||||
// 解包并返回错误
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
@@ -598,6 +803,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
handleSuccess:
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package service
|
||||
|
||||
import "time"
|
||||
|
||||
type ApiKey struct {
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
@@ -15,6 +15,6 @@ type ApiKey struct {
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (k *ApiKey) IsActive() bool {
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
|
||||
@@ -14,39 +14,39 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
|
||||
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(ctx context.Context, key *ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*ApiKey, error)
|
||||
type APIKeyRepository interface {
|
||||
Create(ctx context.Context, key *APIKey) error
|
||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
|
||||
GetOwnerID(ctx context.Context, id int64) (int64, error)
|
||||
GetByKey(ctx context.Context, key string) (*ApiKey, error)
|
||||
Update(ctx context.Context, key *ApiKey) error
|
||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||
Update(ctx context.Context, key *APIKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
|
||||
// ApiKeyCache defines cache operations for API key service
|
||||
type ApiKeyCache interface {
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
type APIKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
@@ -55,40 +55,40 @@ type ApiKeyCache interface {
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
type CreateApiKeyRequest struct {
|
||||
// CreateAPIKeyRequest 创建API Key请求
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
}
|
||||
|
||||
// UpdateApiKeyRequest 更新API Key请求
|
||||
type UpdateApiKeyRequest struct {
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo ApiKeyRepository
|
||||
// APIKeyService API Key服务
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache ApiKeyCache
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo ApiKeyRepository,
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
func NewAPIKeyService(
|
||||
apiKeyRepo APIKeyRepository,
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache ApiKeyCache,
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
) *APIKeyService {
|
||||
return &APIKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -99,7 +99,7 @@ func NewApiKeyService(
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
}
|
||||
|
||||
// 转换为十六进制字符串并添加前缀
|
||||
prefix := s.cfg.Default.ApiKeyPrefix
|
||||
prefix := s.cfg.Default.APIKeyPrefix
|
||||
if prefix == "" {
|
||||
prefix = "sk-"
|
||||
}
|
||||
@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
}
|
||||
|
||||
// ValidateCustomKey 验证自定义API Key格式
|
||||
func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
func (s *APIKeyService) ValidateCustomKey(key string) error {
|
||||
// 检查长度
|
||||
if len(key) < 16 {
|
||||
return ErrApiKeyTooShort
|
||||
return ErrAPIKeyTooShort
|
||||
}
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrApiKeyInvalidChars
|
||||
return ErrAPIKeyInvalidChars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
if count >= apiKeyMaxErrorsPerHour {
|
||||
return ErrApiKeyRateLimited
|
||||
return ErrAPIKeyRateLimited
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||||
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||||
@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
|
||||
}
|
||||
|
||||
// Create 创建API Key
|
||||
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
|
||||
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
// 判断是否使用自定义Key
|
||||
if req.CustomKey != nil && *req.CustomKey != "" {
|
||||
// 检查限流(仅对自定义key进行限流)
|
||||
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
|
||||
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -220,8 +220,8 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
if exists {
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementApiKeyErrorCount(ctx, userID)
|
||||
return nil, ErrApiKeyExists
|
||||
s.incrementAPIKeyErrorCount(ctx, userID)
|
||||
return nil, ErrAPIKeyExists
|
||||
}
|
||||
|
||||
key = *req.CustomKey
|
||||
@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &ApiKey{
|
||||
apiKey := &APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
|
||||
}
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
|
||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
// 尝试从Redis缓存获取
|
||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
||||
|
||||
@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
|
||||
}
|
||||
|
||||
// Update 更新API Key
|
||||
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
|
||||
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
|
||||
// Delete 删除API Key
|
||||
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
|
||||
// 避免加载完整 ApiKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||||
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
|
||||
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
}
|
||||
|
||||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||||
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
|
||||
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
|
||||
// 获取API Key
|
||||
apiKey, err := s.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
|
||||
}
|
||||
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.cache != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 返回用户可以选择的分组:
|
||||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||||
// - 订阅类型分组:用户有有效订阅的
|
||||
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
|
||||
}
|
||||
|
||||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||||
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
return subscribedGroupIDs[group.ID]
|
||||
@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
|
||||
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search api keys: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//go:build unit
|
||||
|
||||
// API Key 服务删除方法的单元测试
|
||||
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
|
||||
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
|
||||
// 包括权限验证、缓存清理和错误处理
|
||||
|
||||
package service
|
||||
@@ -16,12 +16,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
|
||||
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||
//
|
||||
// 设计说明:
|
||||
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
|
||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound)
|
||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||
type apiKeyRepoStub struct {
|
||||
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
|
||||
|
||||
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
|
||||
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
|
||||
return s.ownerID, s.ownerErr
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
|
||||
panic("unexpected SearchApiKeys call")
|
||||
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
// 设计说明:
|
||||
@@ -142,7 +142,7 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
|
||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 1}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
|
||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||
@@ -160,7 +160,7 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 7}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
|
||||
require.NoError(t, err)
|
||||
@@ -170,17 +170,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
|
||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
|
||||
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
|
||||
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 99, 1)
|
||||
require.ErrorIs(t, err, ErrApiKeyNotFound)
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
require.Empty(t, cache.invalidated)
|
||||
}
|
||||
@@ -198,7 +198,7 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
deleteErr: errors.New("delete failed"),
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
|
||||
// 简易模式:跳过所有计费检查
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
|
||||
@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeApiKey,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformAnthropic
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeApiKey,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformOpenAI
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeApiKey,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformGemini
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
|
||||
@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key usage trend: %w", err)
|
||||
}
|
||||
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const (
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeApiKey = "apikey" // API Key类型账号
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -64,13 +64,13 @@ const (
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySMTPPort = "smtp_port" // SMTP端口
|
||||
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySMTPPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySMTPFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
@@ -81,20 +81,27 @@ const (
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocUrl = "doc_url" // 文档链接
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
|
||||
// Gemini 配额策略(JSON)
|
||||
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
|
||||
|
||||
// Model fallback settings
|
||||
SettingKeyEnableModelFallback = "enable_model_fallback"
|
||||
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
|
||||
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
|
||||
SettingKeyFallbackModelGemini = "fallback_model_gemini"
|
||||
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
|
||||
)
|
||||
|
||||
// Admin API Key prefix (distinct from user "sk-" keys)
|
||||
const AdminApiKeyPrefix = "admin-"
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
|
||||
@@ -40,8 +40,8 @@ const (
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
// SmtpConfig SMTP配置
|
||||
type SmtpConfig struct {
|
||||
// SMTPConfig SMTP配置
|
||||
type SMTPConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
|
||||
}
|
||||
}
|
||||
|
||||
// GetSmtpConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
// GetSMTPConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||
keys := []string{
|
||||
SettingKeySmtpHost,
|
||||
SettingKeySmtpPort,
|
||||
SettingKeySmtpUsername,
|
||||
SettingKeySmtpPassword,
|
||||
SettingKeySmtpFrom,
|
||||
SettingKeySmtpFromName,
|
||||
SettingKeySmtpUseTLS,
|
||||
SettingKeySMTPHost,
|
||||
SettingKeySMTPPort,
|
||||
SettingKeySMTPUsername,
|
||||
SettingKeySMTPPassword,
|
||||
SettingKeySMTPFrom,
|
||||
SettingKeySMTPFromName,
|
||||
SettingKeySMTPUseTLS,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
return nil, fmt.Errorf("get smtp settings: %w", err)
|
||||
}
|
||||
|
||||
host := settings[SettingKeySmtpHost]
|
||||
host := settings[SettingKeySMTPHost]
|
||||
if host == "" {
|
||||
return nil, ErrEmailNotConfigured
|
||||
}
|
||||
|
||||
port := 587 // 默认端口
|
||||
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
|
||||
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
useTLS := settings[SettingKeySmtpUseTLS] == "true"
|
||||
useTLS := settings[SettingKeySMTPUseTLS] == "true"
|
||||
|
||||
return &SmtpConfig{
|
||||
return &SMTPConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: settings[SettingKeySmtpUsername],
|
||||
Password: settings[SettingKeySmtpPassword],
|
||||
From: settings[SettingKeySmtpFrom],
|
||||
FromName: settings[SettingKeySmtpFromName],
|
||||
Username: settings[SettingKeySMTPUsername],
|
||||
Password: settings[SettingKeySMTPPassword],
|
||||
From: settings[SettingKeySMTPFrom],
|
||||
FromName: settings[SettingKeySMTPFromName],
|
||||
UseTLS: useTLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendEmail 发送邮件(使用数据库中保存的配置)
|
||||
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
|
||||
config, err := s.GetSmtpConfig(ctx)
|
||||
config, err := s.GetSMTPConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
|
||||
}
|
||||
|
||||
// SendEmailWithConfig 使用指定配置发送邮件
|
||||
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
|
||||
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
|
||||
from := config.From
|
||||
if config.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
|
||||
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||
`, siteName, code)
|
||||
}
|
||||
|
||||
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseTLS {
|
||||
|
||||
@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
|
||||
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
@@ -70,3 +71,224 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// FilterThinkingBlocks removes thinking blocks from request body
|
||||
// Returns filtered body or original body if filtering fails (fail-safe)
|
||||
// This prevents 400 errors from invalid thinking block signatures
|
||||
//
|
||||
// Strategy:
|
||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
||||
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
|
||||
func FilterThinkingBlocks(body []byte) []byte {
|
||||
return filterThinkingBlocksInternal(body, false)
|
||||
}
|
||||
|
||||
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
|
||||
// This is used when upstream returns signature-related 400 errors.
|
||||
//
|
||||
// Key insight:
|
||||
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
|
||||
// - Only HISTORICAL assistant messages have thinking blocks with signatures
|
||||
// - These signatures may be invalid when switching accounts/platforms
|
||||
// - New responses will generate fresh thinking blocks without signature issues
|
||||
//
|
||||
// Strategy:
|
||||
// - Keep thinking.type = "enabled" (preserve user intent)
|
||||
// - Remove thinking/redacted_thinking blocks from historical assistant messages
|
||||
// - Ensure no message has empty content after filtering
|
||||
func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
// Fast path: check for presence of thinking-related keys in messages
|
||||
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) {
|
||||
return body
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
|
||||
// The issue is with historical message signatures, not the thinking mode itself
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
newMessages := make([]any, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
newMessages = append(newMessages, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
// String content or other format - keep as is
|
||||
newMessages = append(newMessages, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
newContent := make([]any, 0, len(content))
|
||||
modifiedThisMsg := false
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
// Remove thinking/redacted_thinking blocks from historical messages
|
||||
// These have signatures that may be invalid across different accounts
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
modifiedThisMsg = true
|
||||
continue
|
||||
}
|
||||
|
||||
newContent = append(newContent, block)
|
||||
}
|
||||
|
||||
if modifiedThisMsg {
|
||||
modified = true
|
||||
// Handle empty content after filtering
|
||||
if len(newContent) == 0 {
|
||||
// For assistant messages, skip entirely (remove from conversation)
|
||||
// For user messages, add placeholder to avoid empty content error
|
||||
if role == "user" {
|
||||
newContent = append(newContent, map[string]any{
|
||||
"type": "text",
|
||||
"text": "(content removed)",
|
||||
})
|
||||
msgMap["content"] = newContent
|
||||
newMessages = append(newMessages, msgMap)
|
||||
}
|
||||
// Skip assistant messages with empty content (don't append)
|
||||
continue
|
||||
}
|
||||
msgMap["content"] = newContent
|
||||
}
|
||||
newMessages = append(newMessages, msgMap)
|
||||
}
|
||||
|
||||
if modified {
|
||||
req["messages"] = newMessages
|
||||
}
|
||||
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// filterThinkingBlocksInternal removes invalid thinking blocks from request
|
||||
// Strategy:
|
||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
||||
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
// Fast path: if body doesn't contain "thinking", skip parsing
|
||||
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"thinking":`)) &&
|
||||
!bytes.Contains(body, []byte(`"thinking" :`)) {
|
||||
return body
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// Check if thinking is enabled
|
||||
thinkingEnabled := false
|
||||
if thinking, ok := req["thinking"].(map[string]any); ok {
|
||||
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
|
||||
thinkingEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
filtered := false
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
newContent := make([]any, 0, len(content))
|
||||
filteredThisMessage := false
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
// When thinking is enabled and this is an assistant message,
|
||||
// only keep thinking blocks with valid signatures
|
||||
if thinkingEnabled && role == "assistant" {
|
||||
signature, _ := blockMap["signature"].(string)
|
||||
if signature != "" && signature != "skip_thought_signature_validator" {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = true
|
||||
filteredThisMessage = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle blocks without type discriminator but with "thinking" key
|
||||
if blockType == "" {
|
||||
if _, hasThinking := blockMap["thinking"]; hasThinking {
|
||||
filtered = true
|
||||
filteredThisMessage = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
newContent = append(newContent, block)
|
||||
}
|
||||
|
||||
if filteredThisMessage {
|
||||
msgMap["content"] = newContent
|
||||
}
|
||||
}
|
||||
|
||||
if !filtered {
|
||||
return body
|
||||
}
|
||||
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -38,3 +39,115 @@ func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocks(t *testing.T) {
|
||||
containsThinkingBlock := func(body []byte) bool {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return false
|
||||
}
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
if blockType == "thinking" {
|
||||
return true
|
||||
}
|
||||
if blockType == "" {
|
||||
if _, hasThinking := blockMap["thinking"]; hasThinking {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldFilter bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "filters thinking blocks",
|
||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "handles no thinking blocks",
|
||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "handles invalid JSON gracefully",
|
||||
input: `{invalid json`,
|
||||
shouldFilter: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "handles multiple messages with thinking blocks",
|
||||
input: `{"messages":[{"role":"user","content":[{"type":"text","text":"A"}]},{"role":"assistant","content":[{"type":"thinking","thinking":"think"},{"type":"text","text":"B"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "filters thinking blocks without type discriminator",
|
||||
input: `{"messages":[{"role":"assistant","content":[{"thinking":{"text":"internal"}},{"type":"text","text":"B"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "does not filter tool_use input fields named thinking",
|
||||
input: `{"messages":[{"role":"user","content":[{"type":"tool_use","id":"t1","name":"foo","input":{"thinking":"keepme","x":1}},{"type":"text","text":"Hello"}]}]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "handles empty messages array",
|
||||
input: `{"messages":[]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "handles missing messages field",
|
||||
input: `{"model":"claude-3"}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := FilterThinkingBlocks([]byte(tt.input))
|
||||
|
||||
if tt.expectError {
|
||||
// For invalid JSON, should return original
|
||||
require.Equal(t, tt.input, string(result))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.shouldFilter {
|
||||
require.False(t, containsThinkingBlock(result))
|
||||
} else {
|
||||
// Ensure we don't rewrite JSON when no filtering is needed.
|
||||
require.Equal(t, tt.input, string(result))
|
||||
}
|
||||
|
||||
// Verify valid JSON returned (unless input was invalid)
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(result, &parsed)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -543,7 +543,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
@@ -579,7 +579,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
@@ -710,7 +710,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
preferOAuth := platform == PlatformGemini
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
@@ -783,7 +783,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
@@ -799,7 +799,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
@@ -875,7 +875,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
@@ -907,7 +907,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
case AccountTypeOAuth, AccountTypeSetupToken:
|
||||
// Both oauth and setup-token use OAuth token flow
|
||||
return s.getOAuthToken(ctx, account)
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
@@ -1045,7 +1045,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
@@ -1082,8 +1082,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要重试
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
// 优先检测thinking block签名错误(400)并重试一次
|
||||
if resp.StatusCode == 400 {
|
||||
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if readErr == nil {
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if s.isThinkingBlockSignatureError(respBody) {
|
||||
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
// 过滤thinking blocks并重试(使用更激进的过滤)
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
// 使用重试后的响应,继续后续处理
|
||||
if retryResp.StatusCode < 400 {
|
||||
log.Printf("Account %d: signature error retry succeeded", account.ID)
|
||||
} else {
|
||||
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode)
|
||||
}
|
||||
resp = retryResp
|
||||
break
|
||||
}
|
||||
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
|
||||
} else {
|
||||
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
|
||||
}
|
||||
// 重试失败,恢复原始响应体继续处理
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
// 不是thinking签名错误,恢复响应体
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
|
||||
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if attempt < maxRetries {
|
||||
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
|
||||
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
|
||||
@@ -1096,6 +1133,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
|
||||
log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||
for k, v := range resp.Header {
|
||||
log.Printf("[DEBUG] %s: %v", k, v)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
@@ -1119,7 +1163,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if readErr != nil {
|
||||
// ReadAll failed, fall back to normal error handling without consuming the stream
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
@@ -1179,7 +1223,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages"
|
||||
}
|
||||
@@ -1243,10 +1287,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
@@ -1313,12 +1357,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultApiKeyBetaHeader(body []byte) string {
|
||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
||||
modelID := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.ApiKeyHaikuBetaHeader
|
||||
return claude.APIKeyHaikuBetaHeader
|
||||
}
|
||||
return claude.ApiKeyBetaHeader
|
||||
return claude.APIKeyBetaHeader
|
||||
}
|
||||
|
||||
func truncateForLog(b []byte, maxBytes int) string {
|
||||
@@ -1335,6 +1379,41 @@ func truncateForLog(b []byte, maxBytes int) string {
|
||||
return s
|
||||
}
|
||||
|
||||
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
|
||||
// 这类错误可以通过过滤thinking blocks并重试来解决
|
||||
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Log for debugging
|
||||
log.Printf("[SignatureCheck] Checking error message: %s", msg)
|
||||
|
||||
// 检测signature相关的错误(更宽松的匹配)
|
||||
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
|
||||
if strings.Contains(msg, "signature") {
|
||||
log.Printf("[SignatureCheck] Detected signature error")
|
||||
return true
|
||||
}
|
||||
|
||||
// 检测 thinking block 顺序/类型错误
|
||||
// 例如: "Expected `thinking` or `redacted_thinking`, but found `text`"
|
||||
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||||
log.Printf("[SignatureCheck] Detected thinking block type error")
|
||||
return true
|
||||
}
|
||||
|
||||
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
|
||||
// 例如: "all messages must have non-empty content"
|
||||
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
|
||||
log.Printf("[SignatureCheck] Detected empty content error")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
|
||||
// 默认保守:无法识别则不切换。
|
||||
@@ -1383,7 +1462,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 处理上游错误,标记账号状态
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
|
||||
var errType, errMsg string
|
||||
@@ -1695,7 +1780,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
ApiKey *ApiKey
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
@@ -1704,7 +1789,7 @@ type RecordUsageInput struct {
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.ApiKey
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
@@ -1741,7 +1826,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
@@ -1771,7 +1856,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -1781,10 +1867,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if cost.TotalCost > 0 {
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
}
|
||||
@@ -1793,7 +1881,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if cost.ActualCost > 0 {
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
}
|
||||
@@ -1826,7 +1914,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if reqModel != "" {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
@@ -1863,17 +1951,35 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
// 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks)
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
|
||||
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
filteredBody := FilterThinkingBlocks(body)
|
||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
// 标记账号状态(429/529等)
|
||||
@@ -1912,7 +2018,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
@@ -1971,10 +2077,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
return 999
|
||||
}
|
||||
switch a.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
|
||||
return 0
|
||||
}
|
||||
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
originalModel := req.Model
|
||||
mappedModel := req.Model
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(req.Model)
|
||||
}
|
||||
|
||||
@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -539,7 +539,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if tempMatched {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
@@ -614,7 +621,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
|
||||
@@ -636,7 +643,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -825,6 +832,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
|
||||
@@ -842,6 +853,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
@@ -1614,6 +1628,15 @@ type UpstreamHTTPResult struct {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||||
// Log response headers for debugging
|
||||
log.Printf("[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
log.Printf("[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiAPI] ========================================")
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1644,6 +1667,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||||
// Log response headers for debugging
|
||||
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
log.Printf("[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiAPI] ====================================================")
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
@@ -1758,7 +1790,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if apiKey == "" {
|
||||
return nil, errors.New("gemini api_key not configured")
|
||||
@@ -2177,10 +2209,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
|
||||
parts := make([]any, 0)
|
||||
switch content := mm["content"].(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(content) != "" {
|
||||
parts = append(parts, map[string]any{"text": content})
|
||||
}
|
||||
// 字符串形式的 content,保留所有内容(包括空白)
|
||||
parts = append(parts, map[string]any{"text": content})
|
||||
case []any:
|
||||
// 如果只有一个 block,不过滤空白(让上游 API 报错)
|
||||
singleBlock := len(content) == 1
|
||||
|
||||
for _, block := range content {
|
||||
bm, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
@@ -2189,8 +2223,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
|
||||
bt, _ := bm["type"].(string)
|
||||
switch bt {
|
||||
case "text":
|
||||
if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
if text, ok := bm["text"].(string); ok {
|
||||
// 单个 block 时保留所有内容(包括空白)
|
||||
// 多个 blocks 时过滤掉空白
|
||||
if singleBlock || strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
}
|
||||
}
|
||||
case "tool_use":
|
||||
id, _ := bm["id"].(string)
|
||||
|
||||
@@ -121,6 +121,12 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
|
||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
@@ -275,7 +281,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -18,12 +19,23 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TierAIPremium = "AI_PREMIUM"
|
||||
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
|
||||
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
|
||||
TierFree = "FREE"
|
||||
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
|
||||
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
||||
// Canonical tier IDs used by sub2api (2026-aligned).
|
||||
GeminiTierGoogleOneFree = "google_one_free"
|
||||
GeminiTierGoogleAIPro = "google_ai_pro"
|
||||
GeminiTierGoogleAIUltra = "google_ai_ultra"
|
||||
GeminiTierGCPStandard = "gcp_standard"
|
||||
GeminiTierGCPEnterprise = "gcp_enterprise"
|
||||
GeminiTierAIStudioFree = "aistudio_free"
|
||||
GeminiTierAIStudioPaid = "aistudio_paid"
|
||||
GeminiTierGoogleOneUnknown = "google_one_unknown"
|
||||
|
||||
// Legacy/compat tier IDs that may exist in historical data or upstream responses.
|
||||
legacyTierAIPremium = "AI_PREMIUM"
|
||||
legacyTierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
|
||||
legacyTierGoogleOneBasic = "GOOGLE_ONE_BASIC"
|
||||
legacyTierFree = "FREE"
|
||||
legacyTierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
|
||||
legacyTierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -84,7 +96,7 @@ type GeminiAuthURLResult struct {
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType, tierID string) (*GeminiAuthURLResult, error) {
|
||||
state, err := geminicli.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
@@ -109,14 +121,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
|
||||
// OAuth client selection:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
|
||||
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
|
||||
// - ai_studio: requires a user-provided OAuth client.
|
||||
oauthCfg := geminicli.OAuthConfig{
|
||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
if oauthType == "code_assist" {
|
||||
oauthCfg.ClientID = ""
|
||||
oauthCfg.ClientSecret = ""
|
||||
}
|
||||
@@ -127,6 +139,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
ProxyURL: proxyURL,
|
||||
RedirectURI: redirectURI,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
TierID: canonicalGeminiTierIDForOAuthType(oauthType, tierID),
|
||||
OAuthType: oauthType,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
@@ -146,9 +159,9 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
|
||||
// Redirect URI strategy:
|
||||
// - code_assist: use Gemini CLI redirect URI (codeassist.google.com/authcode)
|
||||
// - ai_studio: use localhost callback for manual copy/paste flow
|
||||
if oauthType == "code_assist" {
|
||||
// - built-in Gemini CLI OAuth client: use upstream redirect URI (codeassist.google.com/authcode)
|
||||
// - custom OAuth client: use localhost callback for manual copy/paste flow
|
||||
if isBuiltinClient {
|
||||
redirectURI = geminicli.GeminiCLIRedirectURI
|
||||
} else {
|
||||
redirectURI = geminicli.AIStudioOAuthRedirectURI
|
||||
@@ -174,6 +187,9 @@ type GeminiExchangeCodeInput struct {
|
||||
Code string
|
||||
ProxyID *int64
|
||||
OAuthType string // "code_assist" 或 "ai_studio"
|
||||
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
|
||||
// If empty, the service will fall back to the tier stored in the OAuth session (if any).
|
||||
TierID string
|
||||
}
|
||||
|
||||
type GeminiTokenInfo struct {
|
||||
@@ -185,7 +201,7 @@ type GeminiTokenInfo struct {
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
|
||||
TierID string `json:"tier_id,omitempty"` // Canonical tier id (e.g. google_one_free, gcp_standard, aistudio_free)
|
||||
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
|
||||
}
|
||||
|
||||
@@ -204,6 +220,90 @@ func validateTierID(tierID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func canonicalGeminiTierID(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lower := strings.ToLower(raw)
|
||||
switch lower {
|
||||
case GeminiTierGoogleOneFree,
|
||||
GeminiTierGoogleAIPro,
|
||||
GeminiTierGoogleAIUltra,
|
||||
GeminiTierGCPStandard,
|
||||
GeminiTierGCPEnterprise,
|
||||
GeminiTierAIStudioFree,
|
||||
GeminiTierAIStudioPaid,
|
||||
GeminiTierGoogleOneUnknown:
|
||||
return lower
|
||||
}
|
||||
|
||||
upper := strings.ToUpper(raw)
|
||||
switch upper {
|
||||
// Google One legacy tiers
|
||||
case legacyTierAIPremium:
|
||||
return GeminiTierGoogleAIPro
|
||||
case legacyTierGoogleOneUnlimited:
|
||||
return GeminiTierGoogleAIUltra
|
||||
case legacyTierFree, legacyTierGoogleOneBasic, legacyTierGoogleOneStandard:
|
||||
return GeminiTierGoogleOneFree
|
||||
case legacyTierGoogleOneUnknown:
|
||||
return GeminiTierGoogleOneUnknown
|
||||
|
||||
// Code Assist legacy tiers
|
||||
case "STANDARD", "PRO", "LEGACY":
|
||||
return GeminiTierGCPStandard
|
||||
case "ENTERPRISE", "ULTRA":
|
||||
return GeminiTierGCPEnterprise
|
||||
}
|
||||
|
||||
// Some Code Assist responses use kebab-case tier identifiers.
|
||||
switch lower {
|
||||
case "standard-tier", "pro-tier":
|
||||
return GeminiTierGCPStandard
|
||||
case "ultra-tier":
|
||||
return GeminiTierGCPEnterprise
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func canonicalGeminiTierIDForOAuthType(oauthType, tierID string) string {
|
||||
oauthType = strings.ToLower(strings.TrimSpace(oauthType))
|
||||
canonical := canonicalGeminiTierID(tierID)
|
||||
if canonical == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch oauthType {
|
||||
case "google_one":
|
||||
switch canonical {
|
||||
case GeminiTierGoogleOneFree, GeminiTierGoogleAIPro, GeminiTierGoogleAIUltra:
|
||||
return canonical
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
case "code_assist":
|
||||
switch canonical {
|
||||
case GeminiTierGCPStandard, GeminiTierGCPEnterprise:
|
||||
return canonical
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
case "ai_studio":
|
||||
switch canonical {
|
||||
case GeminiTierAIStudioFree, GeminiTierAIStudioPaid:
|
||||
return canonical
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
default:
|
||||
// Unknown oauth type: accept canonical tier.
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
|
||||
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
|
||||
// Prioritizes IsDefault tier, falls back to first non-empty tier
|
||||
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
|
||||
@@ -229,45 +329,61 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string
|
||||
|
||||
// inferGoogleOneTier infers Google One tier from Drive storage limit
|
||||
func inferGoogleOneTier(storageBytes int64) string {
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
|
||||
|
||||
if storageBytes <= 0 {
|
||||
return TierGoogleOneUnknown
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
|
||||
return GeminiTierGoogleOneUnknown
|
||||
}
|
||||
|
||||
if storageBytes > StorageTierUnlimited {
|
||||
return TierGoogleOneUnlimited
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
|
||||
return GeminiTierGoogleAIUltra
|
||||
}
|
||||
if storageBytes >= StorageTierAIPremium {
|
||||
return TierAIPremium
|
||||
}
|
||||
if storageBytes >= StorageTierStandard {
|
||||
return TierGoogleOneStandard
|
||||
}
|
||||
if storageBytes >= StorageTierBasic {
|
||||
return TierGoogleOneBasic
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
|
||||
return GeminiTierGoogleAIPro
|
||||
}
|
||||
if storageBytes >= StorageTierFree {
|
||||
return TierFree
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
|
||||
return GeminiTierGoogleOneFree
|
||||
}
|
||||
return TierGoogleOneUnknown
|
||||
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
|
||||
return GeminiTierGoogleOneUnknown
|
||||
}
|
||||
|
||||
// fetchGoogleOneTier fetches Google One tier from Drive API
|
||||
// FetchGoogleOneTier fetches Google One tier from Drive API.
|
||||
// Note: LoadCodeAssist API is NOT called for Google One accounts because:
|
||||
// 1. It's designed for GCP IAM (enterprise), not personal Google accounts
|
||||
// 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com
|
||||
// 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated
|
||||
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
|
||||
log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
|
||||
|
||||
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
|
||||
log.Printf("[GeminiOAuth] Calling Drive API for storage quota...")
|
||||
driveClient := geminicli.NewDriveClient()
|
||||
|
||||
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Check if it's a 403 (scope not granted)
|
||||
if strings.Contains(err.Error(), "status 403") {
|
||||
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
|
||||
return TierGoogleOneUnknown, nil, err
|
||||
log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err)
|
||||
return GeminiTierGoogleOneUnknown, nil, err
|
||||
}
|
||||
// Other errors
|
||||
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
|
||||
return TierGoogleOneUnknown, nil, err
|
||||
log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err)
|
||||
return GeminiTierGoogleOneUnknown, nil, err
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
|
||||
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
|
||||
|
||||
tierID := inferGoogleOneTier(storageInfo.Limit)
|
||||
log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID)
|
||||
|
||||
return tierID, storageInfo, nil
|
||||
}
|
||||
|
||||
@@ -326,11 +442,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========")
|
||||
log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID)
|
||||
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
log.Printf("[GeminiOAuth] ERROR: Session not found or expired")
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||
log.Printf("[GeminiOAuth] ERROR: Invalid state")
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
@@ -341,6 +462,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL)
|
||||
|
||||
redirectURI := session.RedirectURI
|
||||
|
||||
@@ -349,6 +471,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID)
|
||||
|
||||
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
|
||||
if oauthType == "ai_studio" {
|
||||
@@ -374,8 +498,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Token exchange successful")
|
||||
log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope)
|
||||
log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
|
||||
|
||||
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
@@ -391,43 +520,91 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
projectID := sessionProjectID
|
||||
var tierID string
|
||||
fallbackTierID := canonicalGeminiTierIDForOAuthType(oauthType, input.TierID)
|
||||
if fallbackTierID == "" {
|
||||
fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========")
|
||||
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
|
||||
// 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API
|
||||
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别
|
||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||
switch oauthType {
|
||||
case "code_assist":
|
||||
log.Printf("[GeminiOAuth] Processing code_assist OAuth type")
|
||||
if projectID == "" {
|
||||
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
var err error
|
||||
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
|
||||
// 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID
|
||||
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
|
||||
} else {
|
||||
tierID = fetchedTierID
|
||||
log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
|
||||
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
||||
}
|
||||
// tierID 缺失时使用默认值
|
||||
// Prefer auto-detected tier; fall back to user-selected tier.
|
||||
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
|
||||
if tierID == "" {
|
||||
tierID = "LEGACY"
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
} else {
|
||||
tierID = GeminiTierGCPStandard
|
||||
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
|
||||
|
||||
case "google_one":
|
||||
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
|
||||
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
|
||||
// Attempt to fetch Drive storage tier
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||
var storageInfo *geminicli.DriveStorageInfo
|
||||
var err error
|
||||
tierID, storageInfo, err = s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Log warning but don't block - use fallback
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
|
||||
tierID = TierGoogleOneUnknown
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
|
||||
tierID = ""
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
|
||||
if storageInfo != nil {
|
||||
log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
|
||||
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
|
||||
}
|
||||
}
|
||||
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
|
||||
if tierID == "" || tierID == GeminiTierGoogleOneUnknown {
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
} else {
|
||||
tierID = GeminiTierGoogleOneFree
|
||||
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID)
|
||||
|
||||
// Store Drive info in extra field for caching
|
||||
if storageInfo != nil {
|
||||
@@ -447,12 +624,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
|
||||
return tokenInfo, nil
|
||||
}
|
||||
}
|
||||
// ai_studio 模式不设置 tierID,保持为空
|
||||
|
||||
return &GeminiTokenInfo{
|
||||
case "ai_studio":
|
||||
// No automatic tier detection for AI Studio OAuth; rely on user selection.
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
} else {
|
||||
tierID = GeminiTierAIStudioFree
|
||||
}
|
||||
|
||||
default:
|
||||
log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========")
|
||||
|
||||
result := &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
@@ -462,7 +652,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
ProjectID: projectID,
|
||||
TierID: tierID,
|
||||
OAuthType: oauthType,
|
||||
}, nil
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========")
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
|
||||
@@ -558,6 +751,17 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
// Backward compatibility for google_one:
|
||||
// - New behavior: when a custom OAuth client is configured, google_one will use it.
|
||||
// - Old behavior: google_one always used the built-in Gemini CLI OAuth client.
|
||||
// If an existing account was authorized with the built-in client, refreshing with the custom client
|
||||
// will fail with "unauthorized_client". Retry with the built-in client (code_assist path forces it).
|
||||
if err != nil && oauthType == "google_one" && strings.Contains(err.Error(), "unauthorized_client") && s.GetOAuthConfig().AIStudioOAuthEnabled {
|
||||
if alt, altErr := s.RefreshToken(ctx, "code_assist", refreshToken, proxyURL); altErr == nil {
|
||||
tokenInfo = alt
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Provide a more actionable error for common OAuth client mismatch issues.
|
||||
if strings.Contains(err.Error(), "unauthorized_client") {
|
||||
@@ -583,13 +787,14 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
case "code_assist":
|
||||
// 先设置默认值或保留旧值,确保 tier_id 始终有值
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
} else {
|
||||
tokenInfo.TierID = "LEGACY" // 默认值
|
||||
tokenInfo.TierID = canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
|
||||
}
|
||||
if tokenInfo.TierID == "" {
|
||||
tokenInfo.TierID = GeminiTierGCPStandard
|
||||
}
|
||||
|
||||
// 尝试自动探测 project_id 和 tier_id
|
||||
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
|
||||
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || tokenInfo.TierID == ""
|
||||
if needDetect {
|
||||
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
@@ -598,9 +803,10 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
// 只有当原来没有 tier_id 且探测成功时才更新
|
||||
if existingTierID == "" && tierID != "" {
|
||||
tokenInfo.TierID = tierID
|
||||
if tierID != "" {
|
||||
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" {
|
||||
tokenInfo.TierID = canonical
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -609,6 +815,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
||||
}
|
||||
case "google_one":
|
||||
canonicalExistingTier := canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
|
||||
// Check if tier cache is stale (> 24 hours)
|
||||
needsRefresh := true
|
||||
if account.Extra != nil {
|
||||
@@ -617,30 +824,37 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
if time.Since(updatedAt) <= 24*time.Hour {
|
||||
needsRefresh = false
|
||||
// Use cached tier
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
}
|
||||
tokenInfo.TierID = canonicalExistingTier
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tokenInfo.TierID == "" {
|
||||
tokenInfo.TierID = canonicalExistingTier
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err == nil && storageInfo != nil {
|
||||
tokenInfo.TierID = tierID
|
||||
tokenInfo.Extra = map[string]any{
|
||||
"drive_storage_limit": storageInfo.Limit,
|
||||
"drive_storage_usage": storageInfo.Usage,
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
if err == nil {
|
||||
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" && canonical != GeminiTierGoogleOneUnknown {
|
||||
tokenInfo.TierID = canonical
|
||||
}
|
||||
if storageInfo != nil {
|
||||
tokenInfo.Extra = map[string]any{
|
||||
"drive_storage_limit": storageInfo.Limit,
|
||||
"drive_storage_usage": storageInfo.Usage,
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tokenInfo.TierID == "" || tokenInfo.TierID == GeminiTierGoogleOneUnknown {
|
||||
if canonicalExistingTier != "" {
|
||||
tokenInfo.TierID = canonicalExistingTier
|
||||
} else {
|
||||
// Fallback to cached or unknown
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
} else {
|
||||
tokenInfo.TierID = TierGoogleOneUnknown
|
||||
}
|
||||
tokenInfo.TierID = GeminiTierGoogleOneFree
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -669,6 +883,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
||||
// Validate tier_id before storing
|
||||
if err := validateTierID(tokenInfo.TierID); err == nil {
|
||||
creds["tier_id"] = tokenInfo.TierID
|
||||
fmt.Printf("[GeminiOAuth] Storing tier_id: %s\n", tokenInfo.TierID)
|
||||
} else {
|
||||
fmt.Printf("[GeminiOAuth] Invalid tier_id %s: %v\n", tokenInfo.TierID, err)
|
||||
}
|
||||
// Silently skip invalid tier_id (don't block account creation)
|
||||
}
|
||||
@@ -698,7 +915,13 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
|
||||
tierID := "LEGACY"
|
||||
if loadResp != nil {
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
// First try to get tier from currentTier/paidTier fields
|
||||
if tier := loadResp.GetTier(); tier != "" {
|
||||
tierID = tier
|
||||
} else {
|
||||
// Fallback to extracting from allowedTiers
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
}
|
||||
}
|
||||
|
||||
// If LoadCodeAssist returned a project, use it
|
||||
|
||||
@@ -1,50 +1,129 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
func TestInferGoogleOneTier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
storageBytes int64
|
||||
expectedTier string
|
||||
}{
|
||||
{"Negative storage", -1, TierGoogleOneUnknown},
|
||||
{"Zero storage", 0, TierGoogleOneUnknown},
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
)
|
||||
|
||||
// Free tier boundary (15GB)
|
||||
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
|
||||
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
|
||||
{"Free tier (15GB)", StorageTierFree, TierFree},
|
||||
func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Basic tier boundary (100GB)
|
||||
{"Between free and basic", 50 * GB, TierFree},
|
||||
{"Just below basic tier", StorageTierBasic - 1, TierFree},
|
||||
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
|
||||
type testCase struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
oauthType string
|
||||
projectID string
|
||||
wantClientID string
|
||||
wantRedirect string
|
||||
wantScope string
|
||||
wantProjectID string
|
||||
wantErrSubstr string
|
||||
}
|
||||
|
||||
// Standard tier boundary (200GB)
|
||||
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
|
||||
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
|
||||
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
|
||||
|
||||
// AI Premium tier boundary (2TB)
|
||||
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
|
||||
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
|
||||
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
|
||||
|
||||
// Unlimited tier boundary (> 100TB)
|
||||
{"Between premium and unlimited", 50 * TB, TierAIPremium},
|
||||
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
|
||||
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
|
||||
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
|
||||
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "google_one uses built-in client when not configured and redirects to upstream",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{},
|
||||
},
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
name: "google_one uses custom client when configured and redirects to localhost",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
},
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
|
||||
wantScope: geminicli.DefaultGoogleOneScopes,
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
name: "code_assist always forces built-in client even when custom client configured",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
},
|
||||
},
|
||||
oauthType: "code_assist",
|
||||
projectID: "my-gcp-project",
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "my-gcp-project",
|
||||
},
|
||||
{
|
||||
name: "ai_studio requires custom client",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{},
|
||||
},
|
||||
},
|
||||
oauthType: "ai_studio",
|
||||
wantErrSubstr: "AI Studio OAuth requires a custom OAuth Client",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := inferGoogleOneTier(tt.storageBytes)
|
||||
if result != tt.expectedTier {
|
||||
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
|
||||
tt.storageBytes, result, tt.expectedTier)
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
|
||||
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
|
||||
if tt.wantErrSubstr != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
|
||||
t.Fatalf("expected error containing %q, got: %v", tt.wantErrSubstr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAuthURL returned error: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(got.AuthURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse auth_url: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
|
||||
if gotState := q.Get("state"); gotState != got.State {
|
||||
t.Fatalf("state mismatch: query=%q result=%q", gotState, got.State)
|
||||
}
|
||||
if gotClientID := q.Get("client_id"); gotClientID != tt.wantClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q", gotClientID, tt.wantClientID)
|
||||
}
|
||||
if gotRedirect := q.Get("redirect_uri"); gotRedirect != tt.wantRedirect {
|
||||
t.Fatalf("redirect_uri mismatch: got=%q want=%q", gotRedirect, tt.wantRedirect)
|
||||
}
|
||||
if gotScope := q.Get("scope"); gotScope != tt.wantScope {
|
||||
t.Fatalf("scope mismatch: got=%q want=%q", gotScope, tt.wantScope)
|
||||
}
|
||||
if gotProjectID := q.Get("project_id"); gotProjectID != tt.wantProjectID {
|
||||
t.Fatalf("project_id mismatch: got=%q want=%q", gotProjectID, tt.wantProjectID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,13 +20,24 @@ const (
|
||||
geminiModelFlash geminiModelClass = "flash"
|
||||
)
|
||||
|
||||
type GeminiDailyQuota struct {
|
||||
ProRPD int64
|
||||
FlashRPD int64
|
||||
type GeminiQuota struct {
|
||||
// SharedRPD is a shared requests-per-day pool across models.
|
||||
// When SharedRPD > 0, callers should treat ProRPD/FlashRPD as not applicable for daily quota checks.
|
||||
SharedRPD int64 `json:"shared_rpd,omitempty"`
|
||||
// SharedRPM is a shared requests-per-minute pool across models.
|
||||
// When SharedRPM > 0, callers should treat ProRPM/FlashRPM as not applicable for minute quota checks.
|
||||
SharedRPM int64 `json:"shared_rpm,omitempty"`
|
||||
|
||||
// Per-model quotas (AI Studio / API key).
|
||||
// A value of -1 means "unlimited" (pay-as-you-go).
|
||||
ProRPD int64 `json:"pro_rpd,omitempty"`
|
||||
ProRPM int64 `json:"pro_rpm,omitempty"`
|
||||
FlashRPD int64 `json:"flash_rpd,omitempty"`
|
||||
FlashRPM int64 `json:"flash_rpm,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiTierPolicy struct {
|
||||
Quota GeminiDailyQuota
|
||||
Quota GeminiQuota
|
||||
Cooldown time.Duration
|
||||
}
|
||||
|
||||
@@ -45,10 +56,27 @@ type GeminiUsageTotals struct {
|
||||
|
||||
const geminiQuotaCacheTTL = time.Minute
|
||||
|
||||
type geminiQuotaOverrides struct {
|
||||
type geminiQuotaOverridesV1 struct {
|
||||
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
|
||||
}
|
||||
|
||||
type geminiQuotaOverridesV2 struct {
|
||||
QuotaRules map[string]geminiQuotaRuleOverride `json:"quota_rules"`
|
||||
}
|
||||
|
||||
type geminiQuotaRuleOverride struct {
|
||||
SharedRPD *int64 `json:"shared_rpd,omitempty"`
|
||||
SharedRPM *int64 `json:"rpm,omitempty"`
|
||||
GeminiPro *geminiModelQuotaOverride `json:"gemini_pro,omitempty"`
|
||||
GeminiFlash *geminiModelQuotaOverride `json:"gemini_flash,omitempty"`
|
||||
Desc *string `json:"desc,omitempty"`
|
||||
}
|
||||
|
||||
type geminiModelQuotaOverride struct {
|
||||
RPD *int64 `json:"rpd,omitempty"`
|
||||
RPM *int64 `json:"rpm,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiQuotaService struct {
|
||||
cfg *config.Config
|
||||
settingRepo SettingRepository
|
||||
@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
||||
if s.cfg != nil {
|
||||
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
|
||||
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
|
||||
var overrides geminiQuotaOverrides
|
||||
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
|
||||
log.Printf("gemini quota: parse config policy failed: %v", err)
|
||||
raw := []byte(s.cfg.Gemini.Quota.Policy)
|
||||
var overridesV2 geminiQuotaOverridesV2
|
||||
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
|
||||
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
|
||||
} else {
|
||||
policy.ApplyOverrides(overrides.Tiers)
|
||||
var overridesV1 geminiQuotaOverridesV1
|
||||
if err := json.Unmarshal(raw, &overridesV1); err != nil {
|
||||
log.Printf("gemini quota: parse config policy failed: %v", err)
|
||||
} else {
|
||||
policy.ApplyOverrides(overridesV1.Tiers)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
||||
if err != nil && !errors.Is(err, ErrSettingNotFound) {
|
||||
log.Printf("gemini quota: load setting failed: %v", err)
|
||||
} else if strings.TrimSpace(value) != "" {
|
||||
var overrides geminiQuotaOverrides
|
||||
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
|
||||
log.Printf("gemini quota: parse setting failed: %v", err)
|
||||
raw := []byte(value)
|
||||
var overridesV2 geminiQuotaOverridesV2
|
||||
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
|
||||
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
|
||||
} else {
|
||||
policy.ApplyOverrides(overrides.Tiers)
|
||||
var overridesV1 geminiQuotaOverridesV1
|
||||
if err := json.Unmarshal(raw, &overridesV1); err != nil {
|
||||
log.Printf("gemini quota: parse setting failed: %v", err)
|
||||
} else {
|
||||
policy.ApplyOverrides(overridesV1.Tiers)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
||||
return policy
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
|
||||
if account == nil || !account.IsGeminiCodeAssist() {
|
||||
return GeminiDailyQuota{}, false
|
||||
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiQuota, bool) {
|
||||
if account == nil || account.Platform != PlatformGemini {
|
||||
return GeminiQuota{}, false
|
||||
}
|
||||
|
||||
// Map (oauth_type + tier_id) to a canonical policy tier key.
|
||||
// This keeps the policy table stable even if upstream tier_id strings vary.
|
||||
tierKey := geminiQuotaTierKeyForAccount(account)
|
||||
if tierKey == "" {
|
||||
return GeminiQuota{}, false
|
||||
}
|
||||
|
||||
policy := s.Policy(ctx)
|
||||
return policy.QuotaForTier(account.GeminiTierID())
|
||||
return policy.QuotaForTier(tierKey)
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
|
||||
@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string)
|
||||
return policy.CooldownForTier(tierID)
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) CooldownForAccount(ctx context.Context, account *Account) time.Duration {
|
||||
if s == nil || account == nil || account.Platform != PlatformGemini {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
tierKey := geminiQuotaTierKeyForAccount(account)
|
||||
if strings.TrimSpace(tierKey) == "" {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
return s.CooldownForTier(ctx, tierKey)
|
||||
}
|
||||
|
||||
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
|
||||
return &GeminiQuotaPolicy{
|
||||
tiers: map[string]GeminiTierPolicy{
|
||||
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
|
||||
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
|
||||
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
|
||||
// --- AI Studio / API Key (per-model) ---
|
||||
// aistudio_free:
|
||||
// - gemini_pro: 50 RPD / 2 RPM
|
||||
// - gemini_flash: 1500 RPD / 15 RPM
|
||||
GeminiTierAIStudioFree: {Quota: GeminiQuota{ProRPD: 50, ProRPM: 2, FlashRPD: 1500, FlashRPM: 15}, Cooldown: 30 * time.Minute},
|
||||
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
|
||||
GeminiTierAIStudioPaid: {Quota: GeminiQuota{ProRPD: -1, ProRPM: 1000, FlashRPD: -1, FlashRPM: 2000}, Cooldown: 5 * time.Minute},
|
||||
|
||||
// --- Google One (shared pool) ---
|
||||
GeminiTierGoogleOneFree: {Quota: GeminiQuota{SharedRPD: 1000, SharedRPM: 60}, Cooldown: 30 * time.Minute},
|
||||
GeminiTierGoogleAIPro: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
GeminiTierGoogleAIUltra: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
|
||||
// --- GCP Code Assist (shared pool) ---
|
||||
GeminiTierGCPStandard: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
GeminiTierGCPEnterprise: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
|
||||
if !ok {
|
||||
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
||||
}
|
||||
// Backward-compatible overrides:
|
||||
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
|
||||
// - Otherwise apply per-model overrides.
|
||||
if override.ProRPD != nil {
|
||||
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
|
||||
if policy.Quota.SharedRPD > 0 {
|
||||
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
|
||||
} else {
|
||||
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
|
||||
}
|
||||
}
|
||||
if override.FlashRPD != nil {
|
||||
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
|
||||
if policy.Quota.SharedRPD > 0 {
|
||||
// No separate flash RPD for shared tiers.
|
||||
} else {
|
||||
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.FlashRPD)
|
||||
}
|
||||
}
|
||||
if override.CooldownMinutes != nil {
|
||||
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
|
||||
@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
|
||||
func (p *GeminiQuotaPolicy) ApplyQuotaRulesOverrides(rules map[string]geminiQuotaRuleOverride) {
|
||||
if p == nil || len(rules) == 0 {
|
||||
return
|
||||
}
|
||||
for rawID, override := range rules {
|
||||
tierID := normalizeGeminiTierID(rawID)
|
||||
if tierID == "" {
|
||||
continue
|
||||
}
|
||||
policy, ok := p.tiers[tierID]
|
||||
if !ok {
|
||||
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
||||
}
|
||||
|
||||
if override.SharedRPD != nil {
|
||||
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.SharedRPD)
|
||||
}
|
||||
if override.SharedRPM != nil {
|
||||
policy.Quota.SharedRPM = clampGeminiQuotaRPM(*override.SharedRPM)
|
||||
}
|
||||
if override.GeminiPro != nil {
|
||||
if override.GeminiPro.RPD != nil {
|
||||
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiPro.RPD)
|
||||
}
|
||||
if override.GeminiPro.RPM != nil {
|
||||
policy.Quota.ProRPM = clampGeminiQuotaRPM(*override.GeminiPro.RPM)
|
||||
}
|
||||
}
|
||||
if override.GeminiFlash != nil {
|
||||
if override.GeminiFlash.RPD != nil {
|
||||
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiFlash.RPD)
|
||||
}
|
||||
if override.GeminiFlash.RPM != nil {
|
||||
policy.Quota.FlashRPM = clampGeminiQuotaRPM(*override.GeminiFlash.RPM)
|
||||
}
|
||||
}
|
||||
|
||||
p.tiers[tierID] = policy
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiQuota, bool) {
|
||||
policy, ok := p.policyForTier(tierID)
|
||||
if !ok {
|
||||
return GeminiDailyQuota{}, false
|
||||
return GeminiQuota{}, false
|
||||
}
|
||||
return policy.Quota, true
|
||||
}
|
||||
@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool
|
||||
return GeminiTierPolicy{}, false
|
||||
}
|
||||
normalized := normalizeGeminiTierID(tierID)
|
||||
if normalized == "" {
|
||||
normalized = "LEGACY"
|
||||
}
|
||||
if policy, ok := p.tiers[normalized]; ok {
|
||||
return policy, true
|
||||
}
|
||||
policy, ok := p.tiers["LEGACY"]
|
||||
return policy, ok
|
||||
return GeminiTierPolicy{}, false
|
||||
}
|
||||
|
||||
func normalizeGeminiTierID(tierID string) string {
|
||||
return strings.ToUpper(strings.TrimSpace(tierID))
|
||||
tierID = strings.TrimSpace(tierID)
|
||||
if tierID == "" {
|
||||
return ""
|
||||
}
|
||||
// Prefer canonical mapping (handles legacy tier strings).
|
||||
if canonical := canonicalGeminiTierID(tierID); canonical != "" {
|
||||
return canonical
|
||||
}
|
||||
// Accept older policy keys that used uppercase names.
|
||||
switch strings.ToUpper(tierID) {
|
||||
case "AISTUDIO_FREE":
|
||||
return GeminiTierAIStudioFree
|
||||
case "AISTUDIO_PAID":
|
||||
return GeminiTierAIStudioPaid
|
||||
case "GOOGLE_ONE_FREE":
|
||||
return GeminiTierGoogleOneFree
|
||||
case "GOOGLE_AI_PRO":
|
||||
return GeminiTierGoogleAIPro
|
||||
case "GOOGLE_AI_ULTRA":
|
||||
return GeminiTierGoogleAIUltra
|
||||
case "GCP_STANDARD":
|
||||
return GeminiTierGCPStandard
|
||||
case "GCP_ENTERPRISE":
|
||||
return GeminiTierGCPEnterprise
|
||||
}
|
||||
return strings.ToLower(tierID)
|
||||
}
|
||||
|
||||
func clampGeminiQuotaInt64(value int64) int64 {
|
||||
if value < 0 {
|
||||
func clampGeminiQuotaInt64WithUnlimited(value int64) int64 {
|
||||
if value < -1 {
|
||||
return 0
|
||||
}
|
||||
return value
|
||||
@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int {
|
||||
return value
|
||||
}
|
||||
|
||||
func clampGeminiQuotaRPM(value int64) int64 {
|
||||
if value < 0 {
|
||||
return 0
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func geminiCooldownForTier(tierID string) time.Duration {
|
||||
policy := newGeminiQuotaPolicy()
|
||||
return policy.CooldownForTier(tierID)
|
||||
}
|
||||
|
||||
func geminiQuotaTierKeyForAccount(account *Account) string {
|
||||
if account == nil || account.Platform != PlatformGemini {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Note: GeminiOAuthType() already defaults legacy (project_id present) to code_assist.
|
||||
oauthType := strings.ToLower(strings.TrimSpace(account.GeminiOAuthType()))
|
||||
rawTier := strings.TrimSpace(account.GeminiTierID())
|
||||
|
||||
// Prefer the canonical tier stored in credentials.
|
||||
if tierID := canonicalGeminiTierIDForOAuthType(oauthType, rawTier); tierID != "" && tierID != GeminiTierGoogleOneUnknown {
|
||||
return tierID
|
||||
}
|
||||
|
||||
// Fallback defaults when tier_id is missing or unknown.
|
||||
switch oauthType {
|
||||
case "google_one":
|
||||
return GeminiTierGoogleOneFree
|
||||
case "code_assist":
|
||||
return GeminiTierGCPStandard
|
||||
case "ai_studio":
|
||||
return GeminiTierAIStudioFree
|
||||
default:
|
||||
// API Key accounts (type=apikey) have empty oauth_type and are treated as AI Studio.
|
||||
return GeminiTierAIStudioFree
|
||||
}
|
||||
}
|
||||
|
||||
func geminiModelClassFromName(model string) geminiModelClass {
|
||||
name := strings.ToLower(strings.TrimSpace(model))
|
||||
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
|
||||
|
||||
@@ -487,7 +487,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
case AccountTypeOAuth:
|
||||
// OAuth accounts use ChatGPT internal API
|
||||
targetURL = chatgptCodexURL
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// Return appropriate error response
|
||||
var errType, errMsg string
|
||||
@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
ApiKey *ApiKey
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct {
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.ApiKey
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
@@ -1020,22 +1026,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
_ = s.usageLogRepo.Create(ctx, usageLog)
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user