Compare commits

...

24 Commits

Author SHA1 Message Date
Wesley Liddick
a11ac188c2 Merge pull request #738 from DaydreamCoding/feat/ungrouped-key-setting
feat(gateway): 系统设置控制未分组 Key 调度 — Handler 层中间件拦截
2026-03-03 21:03:31 +08:00
Wesley Liddick
60350d298a Merge pull request #735 from alfadb/fix/count-tokens-default-ignore
fix(ops): 默认忽略 count_tokens 404 错误
2026-03-03 21:02:46 +08:00
shaw
838dad8759 feat: 重构 /v1/usage 端点,支持 quota_limited 和 unrestricted 双模式
- quota_limited 模式:返回 Key 级别的总额度、速率限制窗口用量和过期时间
- unrestricted 模式:返回订阅限额或钱包余额信息(向后兼容)
- 新增 model_stats 字段,支持 start_date/end_date 参数查询按模型用量统计
- 提取 buildUsageData/parseUsageDateRange 等辅助方法,减少主函数复杂度
- 新增 APIKeyService.GetRateLimitData 和 UsageService.GetAPIKeyModelStats
2026-03-03 20:59:12 +08:00
shaw
a728dfe0c6 refactor: 重构 api_key_auth 中间件,用 skipBilling 替代 7 处散落的 isUsageQuery
将中间件职责拆分为鉴权(Authentication)和计费执行(Billing Enforcement)两层:
- 鉴权层(disabled/IP/用户状态)始终执行
- 计费层(过期/配额/订阅/余额)用单一 skipBilling 守卫整块控制

/v1/usage 端点只需鉴权不需计费,skipBilling 仅出现 2 处(订阅加载错误处理 + 计费块守卫),
取代了之前 isUsageQuery 散布在 7 个 if 分支中的控制流。
2026-03-03 20:58:00 +08:00
QTom
0c7cbe3566 feat(gateway): 系统设置控制未分组 Key 调度 — Handler 层中间件拦截
新增系统设置 allow_ungrouped_key_scheduling(默认关闭),
未分组的 API Key 在网关请求时直接返回 403,
由 RequireGroupAssignment 中间件统一拦截,
支持 Anthropic / Google 两种错误格式响应。

全栈实现:常量 → 结构体 → 解析/更新/初始化 → DTO → 管理接口 →
中间件 → 路由注册 → 前端设置界面 + i18n。
2026-03-03 19:56:27 +08:00
alfadb
832b0185c7 style: fix gofmt formatting in ops_settings.go
Remove extra space before inline comment to pass golangci-lint gofmt check.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 18:00:49 +08:00
alfadb
b1719b26d1 fix(ops): 默认忽略 count_tokens 404 错误
将 IgnoreCountTokensErrors 默认值从 false 改为 true。

count_tokens 返回 404 是预期业务行为(上游不支持 endpoint,
客户端应 fallback 到本地 tokenizer 估算),不应被视为错误。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 16:50:13 +08:00
shaw
ccf6a921c7 fix: 修复 PR #723 引入的 CI lint 和 test 编译错误
- wire_gen_test.go: 补充 NewTokenRefreshService 缺失的 tempUnschedCache 参数
- config.go, token_refresh_service.go: 修复 gofmt 格式问题
2026-03-03 16:45:29 +08:00
Wesley Liddick
197c570baa Merge pull request #723 from zqq-nuli/fix/oauth-401-temp-unschedulable
fix: OAuth 401 不再永久锁死账号,改用临时不可调度实现自动恢复
2026-03-03 16:33:07 +08:00
shaw
0fe09f1d40 fix: 恢复 PR #682 中被误替换为占位符的 OAuth client_secret
PR #682 (release → main 全量同步) 将 Antigravity 和 Gemini CLI 的
OAuth client_secret 硬编码值替换为了 "GOCSPX-your-client-secret" 占位符,
导致未配置环境变量的部署环境中 token 刷新失败。

恢复内容:
- antigravity/oauth.go: 恢复真实 client_secret
- antigravity/oauth_test.go: 恢复测试断言中的真实值
- geminicli/constants.go: 恢复真实 client_secret
2026-03-03 16:27:28 +08:00
shaw
4a91954532 fix: correct migration 061 checksum and add missing BillingCache mock methods
- Fix fileChecksum for 061 migration: use TrimSpace hash (66207e7a) instead
  of raw sha256sum (97bdd9a3), matching the actual runtime computation
- Add 222b4a09 as accepted DB checksum for 061 migration
- Add missing GetAPIKeyRateLimit/SetAPIKeyRateLimit/UpdateAPIKeyRateLimitUsage/
  InvalidateAPIKeyRateLimit methods to mock BillingCache in test stubs
- Fix NewBillingCacheService call in singleflight test (add apiKeyRepo param)
2026-03-03 16:11:05 +08:00
shaw
b8b5cec35c fix: resolve CI lint errors and test compilation failures for rate limit feature
- Fix errcheck: properly handle rows.Close() error via named return + defer closure
- Fix gofmt: auto-format billing_cache.go, api_key_service.go, billing_cache_service.go
- Add missing rate limit interface methods to 4 test stubs (GetRateLimitData, IncrementRateLimitUsage, ResetRateLimitWindows)
- Fix NewBillingCacheService calls missing the new apiKeyRepo parameter
2026-03-03 15:43:08 +08:00
Wesley Liddick
43c203333e Merge pull request #733 from DaydreamCoding/fix/group-isolation
fix(gateway): 分组隔离 — 禁止未分组账号被跨组调度
2026-03-03 15:10:30 +08:00
Wesley Liddick
1c6393b131 Merge pull request #732 from xvhuan/perf/admin-dashboard-preagg
perf(admin): 优化 Dashboard 大数据量加载(预聚合趋势+异步用户趋势)
2026-03-03 15:10:20 +08:00
Wesley Liddick
22f04e72e5 Merge pull request #731 from xvhuan/fix/061-bounded-backfill-startup
fix(migrations): 061 迁移改为限时分批回填,避免启动阻塞导致 502
2026-03-03 15:08:18 +08:00
shaw
5f3debf65b chore: add migration for api key rate limit fields 2026-03-03 15:05:15 +08:00
shaw
fd8ef27535 Merge branch 'main' of github.com:Wei-Shaw/sub2api
# Conflicts:
2026-03-03 15:04:03 +08:00
shaw
a80ec5d8bb feat: apikey支持5h/1d/7d速率控制 2026-03-03 15:01:10 +08:00
QTom
530a16291c fix(gateway): 分组隔离 — 禁止未分组账号被跨组调度
当 API Key 无分组时,调度仅从未分组账号池中选取。
修复 isAccountInGroup 在 groupID==nil 时的逻辑,
同时补全 scheduler_snapshot_service 和 gemini_compat_service
中的 SimpleMode 保护,确保分组隔离在所有调度路径生效。

新增 ListSchedulableUngroupedByPlatform/s 方法,
使用 Ent 的 Not(HasAccountGroups()) 谓词实现未分组账号隔离。
新增 17 个单元和端到端隔离测试,覆盖所有分支和边界条件。
2026-03-03 13:20:58 +08:00
xvhuan
7be8f4dc6e perf(admin-dashboard): accelerate trend load with pre-aggregation and async user trend 2026-03-03 11:53:54 +08:00
Wesley Liddick
9792b17597 Merge pull request #729 from touwaeriol/pr/fix-admin-menu-visibility
fix(frontend): admin custom menu items not showing in sidebar
2026-03-03 11:35:59 +08:00
ius
99f1e3ff35 fix(migrations): avoid startup outage from 061 full-table backfill 2026-03-03 11:01:22 +08:00
erio
5ba71cd2f1 fix(frontend): admin custom menu items not showing in sidebar
The public settings API filters out menu items with visibility='admin',
so customMenuItemsForAdmin was always empty when reading from
cachedPublicSettings. Fix by loading custom menu items from the admin
settings API (via adminSettingsStore) which returns all items unfiltered.

Changes:
- adminSettings store: store custom_menu_items from admin settings API
- AppSidebar: read admin menu items from adminSettingsStore instead of
  cachedPublicSettings
- CustomPageView: merge public and admin menu items so admin users can
  access admin-only custom pages
2026-03-03 10:45:35 +08:00
zqq61
ec6bcfeb83 fix: OAuth 401 不再永久锁死账号,改用临时不可调度实现自动恢复
OAuth 账号收到 401 时,原逻辑同时设置 expires_at=now() 和 SetError(),
但刷新服务只查询 status=active 的账号,导致 error 状态的账号永远无法
被刷新服务拾取,expires_at=now() 实际上是死代码。

修复:
- OAuth 401 使用 SetTempUnschedulable 替代 SetError,保持 status=active
- 新增 oauth_401_cooldown_minutes 配置项(默认 10 分钟)
- 刷新成功后同步清除 DB 和 Redis 中的临时不可调度状态
- 不可重试错误检查(invalid_grant 等)从 Antigravity 推广到所有平台
- 可重试错误耗尽后不再标记 error,下个刷新周期继续重试

恢复流程:
OAuth 401 → temp_unschedulable + expires_at=now → 刷新服务拾取
  → 成功: 清除 temp_unschedulable → 自动恢复
  → invalid_grant: SetError → 永久禁用
  → 网络错误: 仅记日志 → 下周期重试
2026-03-02 22:54:38 +08:00
84 changed files with 5137 additions and 279 deletions

View File

@@ -58,11 +58,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
promoCodeRepository := repository.NewPromoCodeRepository(client) promoCodeRepository := repository.NewPromoCodeRepository(client)
billingCache := repository.NewBillingCache(redisClient) billingCache := repository.NewBillingCache(redisClient)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) apiKeyRepository := repository.NewAPIKeyRepository(client, db)
apiKeyRepository := repository.NewAPIKeyRepository(client) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig)
userGroupRateRepository := repository.NewUserGroupRateRepository(db) userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
apiKeyService.SetRateLimitCacheInvalidator(billingCache)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
@@ -221,7 +222,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)

View File

@@ -37,12 +37,13 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
nil, nil,
nil, nil,
cfg, cfg,
nil,
) )
accountExpirySvc := service.NewAccountExpiryService(nil, time.Second) accountExpirySvc := service.NewAccountExpiryService(nil, time.Second)
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
pricingSvc := service.NewPricingService(cfg, nil) pricingSvc := service.NewPricingService(cfg, nil)
emailQueueSvc := service.NewEmailQueueService(nil, 1) emailQueueSvc := service.NewEmailQueueService(nil, 1)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)

View File

@@ -48,6 +48,24 @@ type APIKey struct {
QuotaUsed float64 `json:"quota_used,omitempty"` QuotaUsed float64 `json:"quota_used,omitempty"`
// Expiration time for this API key (null = never expires) // Expiration time for this API key (null = never expires)
ExpiresAt *time.Time `json:"expires_at,omitempty"` ExpiresAt *time.Time `json:"expires_at,omitempty"`
// Rate limit in USD per 5 hours (0 = unlimited)
RateLimit5h float64 `json:"rate_limit_5h,omitempty"`
// Rate limit in USD per day (0 = unlimited)
RateLimit1d float64 `json:"rate_limit_1d,omitempty"`
// Rate limit in USD per 7 days (0 = unlimited)
RateLimit7d float64 `json:"rate_limit_7d,omitempty"`
// Used amount in USD for the current 5h window
Usage5h float64 `json:"usage_5h,omitempty"`
// Used amount in USD for the current 1d window
Usage1d float64 `json:"usage_1d,omitempty"`
// Used amount in USD for the current 7d window
Usage7d float64 `json:"usage_7d,omitempty"`
// Start time of the current 5h rate limit window
Window5hStart *time.Time `json:"window_5h_start,omitempty"`
// Start time of the current 1d rate limit window
Window1dStart *time.Time `json:"window_1d_start,omitempty"`
// Start time of the current 7d rate limit window
Window7dStart *time.Time `json:"window_7d_start,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the APIKeyQuery when eager-loading is set. // The values are being populated by the APIKeyQuery when eager-loading is set.
Edges APIKeyEdges `json:"edges"` Edges APIKeyEdges `json:"edges"`
@@ -105,13 +123,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
switch columns[i] { switch columns[i] {
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist: case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
values[i] = new([]byte) values[i] = new([]byte)
case apikey.FieldQuota, apikey.FieldQuotaUsed: case apikey.FieldQuota, apikey.FieldQuotaUsed, apikey.FieldRateLimit5h, apikey.FieldRateLimit1d, apikey.FieldRateLimit7d, apikey.FieldUsage5h, apikey.FieldUsage1d, apikey.FieldUsage7d:
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt: case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt, apikey.FieldWindow5hStart, apikey.FieldWindow1dStart, apikey.FieldWindow7dStart:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
default: default:
values[i] = new(sql.UnknownType) values[i] = new(sql.UnknownType)
@@ -226,6 +244,63 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
_m.ExpiresAt = new(time.Time) _m.ExpiresAt = new(time.Time)
*_m.ExpiresAt = value.Time *_m.ExpiresAt = value.Time
} }
case apikey.FieldRateLimit5h:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field rate_limit_5h", values[i])
} else if value.Valid {
_m.RateLimit5h = value.Float64
}
case apikey.FieldRateLimit1d:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field rate_limit_1d", values[i])
} else if value.Valid {
_m.RateLimit1d = value.Float64
}
case apikey.FieldRateLimit7d:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field rate_limit_7d", values[i])
} else if value.Valid {
_m.RateLimit7d = value.Float64
}
case apikey.FieldUsage5h:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field usage_5h", values[i])
} else if value.Valid {
_m.Usage5h = value.Float64
}
case apikey.FieldUsage1d:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field usage_1d", values[i])
} else if value.Valid {
_m.Usage1d = value.Float64
}
case apikey.FieldUsage7d:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field usage_7d", values[i])
} else if value.Valid {
_m.Usage7d = value.Float64
}
case apikey.FieldWindow5hStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field window_5h_start", values[i])
} else if value.Valid {
_m.Window5hStart = new(time.Time)
*_m.Window5hStart = value.Time
}
case apikey.FieldWindow1dStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field window_1d_start", values[i])
} else if value.Valid {
_m.Window1dStart = new(time.Time)
*_m.Window1dStart = value.Time
}
case apikey.FieldWindow7dStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field window_7d_start", values[i])
} else if value.Valid {
_m.Window7dStart = new(time.Time)
*_m.Window7dStart = value.Time
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
@@ -326,6 +401,39 @@ func (_m *APIKey) String() string {
builder.WriteString("expires_at=") builder.WriteString("expires_at=")
builder.WriteString(v.Format(time.ANSIC)) builder.WriteString(v.Format(time.ANSIC))
} }
builder.WriteString(", ")
builder.WriteString("rate_limit_5h=")
builder.WriteString(fmt.Sprintf("%v", _m.RateLimit5h))
builder.WriteString(", ")
builder.WriteString("rate_limit_1d=")
builder.WriteString(fmt.Sprintf("%v", _m.RateLimit1d))
builder.WriteString(", ")
builder.WriteString("rate_limit_7d=")
builder.WriteString(fmt.Sprintf("%v", _m.RateLimit7d))
builder.WriteString(", ")
builder.WriteString("usage_5h=")
builder.WriteString(fmt.Sprintf("%v", _m.Usage5h))
builder.WriteString(", ")
builder.WriteString("usage_1d=")
builder.WriteString(fmt.Sprintf("%v", _m.Usage1d))
builder.WriteString(", ")
builder.WriteString("usage_7d=")
builder.WriteString(fmt.Sprintf("%v", _m.Usage7d))
builder.WriteString(", ")
if v := _m.Window5hStart; v != nil {
builder.WriteString("window_5h_start=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.Window1dStart; v != nil {
builder.WriteString("window_1d_start=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.Window7dStart; v != nil {
builder.WriteString("window_7d_start=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }

View File

@@ -43,6 +43,24 @@ const (
FieldQuotaUsed = "quota_used" FieldQuotaUsed = "quota_used"
// FieldExpiresAt holds the string denoting the expires_at field in the database. // FieldExpiresAt holds the string denoting the expires_at field in the database.
FieldExpiresAt = "expires_at" FieldExpiresAt = "expires_at"
// FieldRateLimit5h holds the string denoting the rate_limit_5h field in the database.
FieldRateLimit5h = "rate_limit_5h"
// FieldRateLimit1d holds the string denoting the rate_limit_1d field in the database.
FieldRateLimit1d = "rate_limit_1d"
// FieldRateLimit7d holds the string denoting the rate_limit_7d field in the database.
FieldRateLimit7d = "rate_limit_7d"
// FieldUsage5h holds the string denoting the usage_5h field in the database.
FieldUsage5h = "usage_5h"
// FieldUsage1d holds the string denoting the usage_1d field in the database.
FieldUsage1d = "usage_1d"
// FieldUsage7d holds the string denoting the usage_7d field in the database.
FieldUsage7d = "usage_7d"
// FieldWindow5hStart holds the string denoting the window_5h_start field in the database.
FieldWindow5hStart = "window_5h_start"
// FieldWindow1dStart holds the string denoting the window_1d_start field in the database.
FieldWindow1dStart = "window_1d_start"
// FieldWindow7dStart holds the string denoting the window_7d_start field in the database.
FieldWindow7dStart = "window_7d_start"
// EdgeUser holds the string denoting the user edge name in mutations. // EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user" EdgeUser = "user"
// EdgeGroup holds the string denoting the group edge name in mutations. // EdgeGroup holds the string denoting the group edge name in mutations.
@@ -91,6 +109,15 @@ var Columns = []string{
FieldQuota, FieldQuota,
FieldQuotaUsed, FieldQuotaUsed,
FieldExpiresAt, FieldExpiresAt,
FieldRateLimit5h,
FieldRateLimit1d,
FieldRateLimit7d,
FieldUsage5h,
FieldUsage1d,
FieldUsage7d,
FieldWindow5hStart,
FieldWindow1dStart,
FieldWindow7dStart,
} }
// ValidColumn reports if the column name is valid (part of the table columns). // ValidColumn reports if the column name is valid (part of the table columns).
@@ -129,6 +156,18 @@ var (
DefaultQuota float64 DefaultQuota float64
// DefaultQuotaUsed holds the default value on creation for the "quota_used" field. // DefaultQuotaUsed holds the default value on creation for the "quota_used" field.
DefaultQuotaUsed float64 DefaultQuotaUsed float64
// DefaultRateLimit5h holds the default value on creation for the "rate_limit_5h" field.
DefaultRateLimit5h float64
// DefaultRateLimit1d holds the default value on creation for the "rate_limit_1d" field.
DefaultRateLimit1d float64
// DefaultRateLimit7d holds the default value on creation for the "rate_limit_7d" field.
DefaultRateLimit7d float64
// DefaultUsage5h holds the default value on creation for the "usage_5h" field.
DefaultUsage5h float64
// DefaultUsage1d holds the default value on creation for the "usage_1d" field.
DefaultUsage1d float64
// DefaultUsage7d holds the default value on creation for the "usage_7d" field.
DefaultUsage7d float64
) )
// OrderOption defines the ordering options for the APIKey queries. // OrderOption defines the ordering options for the APIKey queries.
@@ -199,6 +238,51 @@ func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
} }
// ByRateLimit5h orders the results by the rate_limit_5h field.
func ByRateLimit5h(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateLimit5h, opts...).ToFunc()
}
// ByRateLimit1d orders the results by the rate_limit_1d field.
func ByRateLimit1d(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateLimit1d, opts...).ToFunc()
}
// ByRateLimit7d orders the results by the rate_limit_7d field.
func ByRateLimit7d(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateLimit7d, opts...).ToFunc()
}
// ByUsage5h orders the results by the usage_5h field.
func ByUsage5h(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUsage5h, opts...).ToFunc()
}
// ByUsage1d orders the results by the usage_1d field.
func ByUsage1d(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUsage1d, opts...).ToFunc()
}
// ByUsage7d orders the results by the usage_7d field.
func ByUsage7d(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUsage7d, opts...).ToFunc()
}
// ByWindow5hStart orders the results by the window_5h_start field.
func ByWindow5hStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWindow5hStart, opts...).ToFunc()
}
// ByWindow1dStart orders the results by the window_1d_start field.
func ByWindow1dStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWindow1dStart, opts...).ToFunc()
}
// ByWindow7dStart orders the results by the window_7d_start field.
func ByWindow7dStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWindow7dStart, opts...).ToFunc()
}
// ByUserField orders the results by user field. // ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) { return func(s *sql.Selector) {

View File

@@ -115,6 +115,51 @@ func ExpiresAt(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v))
} }
// RateLimit5h applies equality check predicate on the "rate_limit_5h" field. It's identical to RateLimit5hEQ.
func RateLimit5h(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v))
}
// RateLimit1d applies equality check predicate on the "rate_limit_1d" field. It's identical to RateLimit1dEQ.
func RateLimit1d(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v))
}
// RateLimit7d applies equality check predicate on the "rate_limit_7d" field. It's identical to RateLimit7dEQ.
func RateLimit7d(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v))
}
// Usage5h applies equality check predicate on the "usage_5h" field. It's identical to Usage5hEQ.
func Usage5h(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v))
}
// Usage1d applies equality check predicate on the "usage_1d" field. It's identical to Usage1dEQ.
func Usage1d(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v))
}
// Usage7d applies equality check predicate on the "usage_7d" field. It's identical to Usage7dEQ.
func Usage7d(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v))
}
// Window5hStart applies equality check predicate on the "window_5h_start" field. It's identical to Window5hStartEQ.
func Window5hStart(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v))
}
// Window1dStart applies equality check predicate on the "window_1d_start" field. It's identical to Window1dStartEQ.
func Window1dStart(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v))
}
// Window7dStart applies equality check predicate on the "window_7d_start" field. It's identical to Window7dStartEQ.
func Window7dStart(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.APIKey { func CreatedAtEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v)) return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v))
@@ -690,6 +735,396 @@ func ExpiresAtNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt)) return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt))
} }
// RateLimit5hEQ applies the EQ predicate on the "rate_limit_5h" field.
func RateLimit5hEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v))
}
// RateLimit5hNEQ applies the NEQ predicate on the "rate_limit_5h" field.
func RateLimit5hNEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldRateLimit5h, v))
}
// RateLimit5hIn applies the In predicate on the "rate_limit_5h" field.
func RateLimit5hIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldRateLimit5h, vs...))
}
// RateLimit5hNotIn applies the NotIn predicate on the "rate_limit_5h" field.
func RateLimit5hNotIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldRateLimit5h, vs...))
}
// RateLimit5hGT applies the GT predicate on the "rate_limit_5h" field.
func RateLimit5hGT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldRateLimit5h, v))
}
// RateLimit5hGTE applies the GTE predicate on the "rate_limit_5h" field.
func RateLimit5hGTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldRateLimit5h, v))
}
// RateLimit5hLT applies the LT predicate on the "rate_limit_5h" field.
func RateLimit5hLT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldRateLimit5h, v))
}
// RateLimit5hLTE applies the LTE predicate on the "rate_limit_5h" field.
func RateLimit5hLTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldRateLimit5h, v))
}
// RateLimit1dEQ applies the EQ predicate on the "rate_limit_1d" field.
func RateLimit1dEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v))
}
// RateLimit1dNEQ applies the NEQ predicate on the "rate_limit_1d" field.
func RateLimit1dNEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldRateLimit1d, v))
}
// RateLimit1dIn applies the In predicate on the "rate_limit_1d" field.
func RateLimit1dIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldRateLimit1d, vs...))
}
// RateLimit1dNotIn applies the NotIn predicate on the "rate_limit_1d" field.
func RateLimit1dNotIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldRateLimit1d, vs...))
}
// RateLimit1dGT applies the GT predicate on the "rate_limit_1d" field.
func RateLimit1dGT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldRateLimit1d, v))
}
// RateLimit1dGTE applies the GTE predicate on the "rate_limit_1d" field.
func RateLimit1dGTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldRateLimit1d, v))
}
// RateLimit1dLT applies the LT predicate on the "rate_limit_1d" field.
func RateLimit1dLT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldRateLimit1d, v))
}
// RateLimit1dLTE applies the LTE predicate on the "rate_limit_1d" field.
func RateLimit1dLTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldRateLimit1d, v))
}
// RateLimit7dEQ applies the EQ predicate on the "rate_limit_7d" field.
func RateLimit7dEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v))
}
// RateLimit7dNEQ applies the NEQ predicate on the "rate_limit_7d" field.
func RateLimit7dNEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldRateLimit7d, v))
}
// RateLimit7dIn applies the In predicate on the "rate_limit_7d" field.
func RateLimit7dIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldRateLimit7d, vs...))
}
// RateLimit7dNotIn applies the NotIn predicate on the "rate_limit_7d" field.
func RateLimit7dNotIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldRateLimit7d, vs...))
}
// RateLimit7dGT applies the GT predicate on the "rate_limit_7d" field.
func RateLimit7dGT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldRateLimit7d, v))
}
// RateLimit7dGTE applies the GTE predicate on the "rate_limit_7d" field.
func RateLimit7dGTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldRateLimit7d, v))
}
// RateLimit7dLT applies the LT predicate on the "rate_limit_7d" field.
func RateLimit7dLT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldRateLimit7d, v))
}
// RateLimit7dLTE applies the LTE predicate on the "rate_limit_7d" field.
func RateLimit7dLTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldRateLimit7d, v))
}
// Usage5hEQ applies the EQ predicate on the "usage_5h" field.
func Usage5hEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v))
}
// Usage5hNEQ applies the NEQ predicate on the "usage_5h" field.
func Usage5hNEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldUsage5h, v))
}
// Usage5hIn applies the In predicate on the "usage_5h" field.
func Usage5hIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldUsage5h, vs...))
}
// Usage5hNotIn applies the NotIn predicate on the "usage_5h" field.
func Usage5hNotIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldUsage5h, vs...))
}
// Usage5hGT applies the GT predicate on the "usage_5h" field.
func Usage5hGT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldUsage5h, v))
}
// Usage5hGTE applies the GTE predicate on the "usage_5h" field.
func Usage5hGTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldUsage5h, v))
}
// Usage5hLT applies the LT predicate on the "usage_5h" field.
func Usage5hLT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldUsage5h, v))
}
// Usage5hLTE applies the LTE predicate on the "usage_5h" field.
func Usage5hLTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldUsage5h, v))
}
// Usage1dEQ applies the EQ predicate on the "usage_1d" field.
func Usage1dEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v))
}
// Usage1dNEQ applies the NEQ predicate on the "usage_1d" field.
func Usage1dNEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldUsage1d, v))
}
// Usage1dIn applies the In predicate on the "usage_1d" field.
func Usage1dIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldUsage1d, vs...))
}
// Usage1dNotIn applies the NotIn predicate on the "usage_1d" field.
func Usage1dNotIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldUsage1d, vs...))
}
// Usage1dGT applies the GT predicate on the "usage_1d" field.
func Usage1dGT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldUsage1d, v))
}
// Usage1dGTE applies the GTE predicate on the "usage_1d" field.
func Usage1dGTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldUsage1d, v))
}
// Usage1dLT applies the LT predicate on the "usage_1d" field.
func Usage1dLT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldUsage1d, v))
}
// Usage1dLTE applies the LTE predicate on the "usage_1d" field.
func Usage1dLTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldUsage1d, v))
}
// Usage7dEQ applies the EQ predicate on the "usage_7d" field.
func Usage7dEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v))
}
// Usage7dNEQ applies the NEQ predicate on the "usage_7d" field.
func Usage7dNEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldUsage7d, v))
}
// Usage7dIn applies the In predicate on the "usage_7d" field.
func Usage7dIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldUsage7d, vs...))
}
// Usage7dNotIn applies the NotIn predicate on the "usage_7d" field.
func Usage7dNotIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldUsage7d, vs...))
}
// Usage7dGT applies the GT predicate on the "usage_7d" field.
func Usage7dGT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldUsage7d, v))
}
// Usage7dGTE applies the GTE predicate on the "usage_7d" field.
func Usage7dGTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldUsage7d, v))
}
// Usage7dLT applies the LT predicate on the "usage_7d" field.
func Usage7dLT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldUsage7d, v))
}
// Usage7dLTE applies the LTE predicate on the "usage_7d" field.
func Usage7dLTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldUsage7d, v))
}
// Window5hStartEQ applies the EQ predicate on the "window_5h_start" field.
func Window5hStartEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v))
}
// Window5hStartNEQ applies the NEQ predicate on the "window_5h_start" field.
func Window5hStartNEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldWindow5hStart, v))
}
// Window5hStartIn applies the In predicate on the "window_5h_start" field.
func Window5hStartIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldWindow5hStart, vs...))
}
// Window5hStartNotIn applies the NotIn predicate on the "window_5h_start" field.
func Window5hStartNotIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldWindow5hStart, vs...))
}
// Window5hStartGT applies the GT predicate on the "window_5h_start" field.
func Window5hStartGT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldWindow5hStart, v))
}
// Window5hStartGTE applies the GTE predicate on the "window_5h_start" field.
func Window5hStartGTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldWindow5hStart, v))
}
// Window5hStartLT applies the LT predicate on the "window_5h_start" field.
func Window5hStartLT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldWindow5hStart, v))
}
// Window5hStartLTE applies the LTE predicate on the "window_5h_start" field.
func Window5hStartLTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldWindow5hStart, v))
}
// Window5hStartIsNil applies the IsNil predicate on the "window_5h_start" field.
func Window5hStartIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldWindow5hStart))
}
// Window5hStartNotNil applies the NotNil predicate on the "window_5h_start" field.
func Window5hStartNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldWindow5hStart))
}
// Window1dStartEQ applies the EQ predicate on the "window_1d_start" field.
func Window1dStartEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v))
}
// Window1dStartNEQ applies the NEQ predicate on the "window_1d_start" field.
func Window1dStartNEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldWindow1dStart, v))
}
// Window1dStartIn applies the In predicate on the "window_1d_start" field.
func Window1dStartIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldWindow1dStart, vs...))
}
// Window1dStartNotIn applies the NotIn predicate on the "window_1d_start" field.
func Window1dStartNotIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldWindow1dStart, vs...))
}
// Window1dStartGT applies the GT predicate on the "window_1d_start" field.
func Window1dStartGT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldWindow1dStart, v))
}
// Window1dStartGTE applies the GTE predicate on the "window_1d_start" field.
func Window1dStartGTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldWindow1dStart, v))
}
// Window1dStartLT applies the LT predicate on the "window_1d_start" field.
func Window1dStartLT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldWindow1dStart, v))
}
// Window1dStartLTE applies the LTE predicate on the "window_1d_start" field.
func Window1dStartLTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldWindow1dStart, v))
}
// Window1dStartIsNil applies the IsNil predicate on the "window_1d_start" field.
func Window1dStartIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldWindow1dStart))
}
// Window1dStartNotNil applies the NotNil predicate on the "window_1d_start" field.
func Window1dStartNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldWindow1dStart))
}
// Window7dStartEQ applies the EQ predicate on the "window_7d_start" field.
func Window7dStartEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v))
}
// Window7dStartNEQ applies the NEQ predicate on the "window_7d_start" field.
func Window7dStartNEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldWindow7dStart, v))
}
// Window7dStartIn applies the In predicate on the "window_7d_start" field.
func Window7dStartIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldWindow7dStart, vs...))
}
// Window7dStartNotIn applies the NotIn predicate on the "window_7d_start" field.
func Window7dStartNotIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldWindow7dStart, vs...))
}
// Window7dStartGT applies the GT predicate on the "window_7d_start" field.
func Window7dStartGT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldWindow7dStart, v))
}
// Window7dStartGTE applies the GTE predicate on the "window_7d_start" field.
func Window7dStartGTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldWindow7dStart, v))
}
// Window7dStartLT applies the LT predicate on the "window_7d_start" field.
func Window7dStartLT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldWindow7dStart, v))
}
// Window7dStartLTE applies the LTE predicate on the "window_7d_start" field.
func Window7dStartLTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldWindow7dStart, v))
}
// Window7dStartIsNil applies the IsNil predicate on the "window_7d_start" field.
func Window7dStartIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldWindow7dStart))
}
// Window7dStartNotNil applies the NotNil predicate on the "window_7d_start" field.
func Window7dStartNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldWindow7dStart))
}
// HasUser applies the HasEdge predicate on the "user" edge. // HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.APIKey { func HasUser() predicate.APIKey {
return predicate.APIKey(func(s *sql.Selector) { return predicate.APIKey(func(s *sql.Selector) {

View File

@@ -181,6 +181,132 @@ func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate {
return _c return _c
} }
// SetRateLimit5h sets the "rate_limit_5h" field.
func (_c *APIKeyCreate) SetRateLimit5h(v float64) *APIKeyCreate {
_c.mutation.SetRateLimit5h(v)
return _c
}
// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableRateLimit5h(v *float64) *APIKeyCreate {
if v != nil {
_c.SetRateLimit5h(*v)
}
return _c
}
// SetRateLimit1d sets the "rate_limit_1d" field.
func (_c *APIKeyCreate) SetRateLimit1d(v float64) *APIKeyCreate {
_c.mutation.SetRateLimit1d(v)
return _c
}
// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableRateLimit1d(v *float64) *APIKeyCreate {
if v != nil {
_c.SetRateLimit1d(*v)
}
return _c
}
// SetRateLimit7d sets the "rate_limit_7d" field.
func (_c *APIKeyCreate) SetRateLimit7d(v float64) *APIKeyCreate {
_c.mutation.SetRateLimit7d(v)
return _c
}
// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableRateLimit7d(v *float64) *APIKeyCreate {
if v != nil {
_c.SetRateLimit7d(*v)
}
return _c
}
// SetUsage5h sets the "usage_5h" field.
func (_c *APIKeyCreate) SetUsage5h(v float64) *APIKeyCreate {
_c.mutation.SetUsage5h(v)
return _c
}
// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableUsage5h(v *float64) *APIKeyCreate {
if v != nil {
_c.SetUsage5h(*v)
}
return _c
}
// SetUsage1d sets the "usage_1d" field.
func (_c *APIKeyCreate) SetUsage1d(v float64) *APIKeyCreate {
_c.mutation.SetUsage1d(v)
return _c
}
// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableUsage1d(v *float64) *APIKeyCreate {
if v != nil {
_c.SetUsage1d(*v)
}
return _c
}
// SetUsage7d sets the "usage_7d" field.
func (_c *APIKeyCreate) SetUsage7d(v float64) *APIKeyCreate {
_c.mutation.SetUsage7d(v)
return _c
}
// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableUsage7d(v *float64) *APIKeyCreate {
if v != nil {
_c.SetUsage7d(*v)
}
return _c
}
// SetWindow5hStart sets the "window_5h_start" field.
func (_c *APIKeyCreate) SetWindow5hStart(v time.Time) *APIKeyCreate {
_c.mutation.SetWindow5hStart(v)
return _c
}
// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableWindow5hStart(v *time.Time) *APIKeyCreate {
if v != nil {
_c.SetWindow5hStart(*v)
}
return _c
}
// SetWindow1dStart sets the "window_1d_start" field.
func (_c *APIKeyCreate) SetWindow1dStart(v time.Time) *APIKeyCreate {
_c.mutation.SetWindow1dStart(v)
return _c
}
// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableWindow1dStart(v *time.Time) *APIKeyCreate {
if v != nil {
_c.SetWindow1dStart(*v)
}
return _c
}
// SetWindow7dStart sets the "window_7d_start" field.
func (_c *APIKeyCreate) SetWindow7dStart(v time.Time) *APIKeyCreate {
_c.mutation.SetWindow7dStart(v)
return _c
}
// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableWindow7dStart(v *time.Time) *APIKeyCreate {
if v != nil {
_c.SetWindow7dStart(*v)
}
return _c
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
return _c.SetUserID(v.ID) return _c.SetUserID(v.ID)
@@ -269,6 +395,30 @@ func (_c *APIKeyCreate) defaults() error {
v := apikey.DefaultQuotaUsed v := apikey.DefaultQuotaUsed
_c.mutation.SetQuotaUsed(v) _c.mutation.SetQuotaUsed(v)
} }
if _, ok := _c.mutation.RateLimit5h(); !ok {
v := apikey.DefaultRateLimit5h
_c.mutation.SetRateLimit5h(v)
}
if _, ok := _c.mutation.RateLimit1d(); !ok {
v := apikey.DefaultRateLimit1d
_c.mutation.SetRateLimit1d(v)
}
if _, ok := _c.mutation.RateLimit7d(); !ok {
v := apikey.DefaultRateLimit7d
_c.mutation.SetRateLimit7d(v)
}
if _, ok := _c.mutation.Usage5h(); !ok {
v := apikey.DefaultUsage5h
_c.mutation.SetUsage5h(v)
}
if _, ok := _c.mutation.Usage1d(); !ok {
v := apikey.DefaultUsage1d
_c.mutation.SetUsage1d(v)
}
if _, ok := _c.mutation.Usage7d(); !ok {
v := apikey.DefaultUsage7d
_c.mutation.SetUsage7d(v)
}
return nil return nil
} }
@@ -313,6 +463,24 @@ func (_c *APIKeyCreate) check() error {
if _, ok := _c.mutation.QuotaUsed(); !ok { if _, ok := _c.mutation.QuotaUsed(); !ok {
return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)} return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)}
} }
if _, ok := _c.mutation.RateLimit5h(); !ok {
return &ValidationError{Name: "rate_limit_5h", err: errors.New(`ent: missing required field "APIKey.rate_limit_5h"`)}
}
if _, ok := _c.mutation.RateLimit1d(); !ok {
return &ValidationError{Name: "rate_limit_1d", err: errors.New(`ent: missing required field "APIKey.rate_limit_1d"`)}
}
if _, ok := _c.mutation.RateLimit7d(); !ok {
return &ValidationError{Name: "rate_limit_7d", err: errors.New(`ent: missing required field "APIKey.rate_limit_7d"`)}
}
if _, ok := _c.mutation.Usage5h(); !ok {
return &ValidationError{Name: "usage_5h", err: errors.New(`ent: missing required field "APIKey.usage_5h"`)}
}
if _, ok := _c.mutation.Usage1d(); !ok {
return &ValidationError{Name: "usage_1d", err: errors.New(`ent: missing required field "APIKey.usage_1d"`)}
}
if _, ok := _c.mutation.Usage7d(); !ok {
return &ValidationError{Name: "usage_7d", err: errors.New(`ent: missing required field "APIKey.usage_7d"`)}
}
if len(_c.mutation.UserIDs()) == 0 { if len(_c.mutation.UserIDs()) == 0 {
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)} return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)}
} }
@@ -391,6 +559,42 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
_spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value)
_node.ExpiresAt = &value _node.ExpiresAt = &value
} }
if value, ok := _c.mutation.RateLimit5h(); ok {
_spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
_node.RateLimit5h = value
}
if value, ok := _c.mutation.RateLimit1d(); ok {
_spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
_node.RateLimit1d = value
}
if value, ok := _c.mutation.RateLimit7d(); ok {
_spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
_node.RateLimit7d = value
}
if value, ok := _c.mutation.Usage5h(); ok {
_spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value)
_node.Usage5h = value
}
if value, ok := _c.mutation.Usage1d(); ok {
_spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value)
_node.Usage1d = value
}
if value, ok := _c.mutation.Usage7d(); ok {
_spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value)
_node.Usage7d = value
}
if value, ok := _c.mutation.Window5hStart(); ok {
_spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value)
_node.Window5hStart = &value
}
if value, ok := _c.mutation.Window1dStart(); ok {
_spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value)
_node.Window1dStart = &value
}
if value, ok := _c.mutation.Window7dStart(); ok {
_spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value)
_node.Window7dStart = &value
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
@@ -697,6 +901,168 @@ func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert {
return u return u
} }
// SetRateLimit5h sets the "rate_limit_5h" field.
func (u *APIKeyUpsert) SetRateLimit5h(v float64) *APIKeyUpsert {
u.Set(apikey.FieldRateLimit5h, v)
return u
}
// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateRateLimit5h() *APIKeyUpsert {
u.SetExcluded(apikey.FieldRateLimit5h)
return u
}
// AddRateLimit5h adds v to the "rate_limit_5h" field.
func (u *APIKeyUpsert) AddRateLimit5h(v float64) *APIKeyUpsert {
u.Add(apikey.FieldRateLimit5h, v)
return u
}
// SetRateLimit1d sets the "rate_limit_1d" field.
func (u *APIKeyUpsert) SetRateLimit1d(v float64) *APIKeyUpsert {
u.Set(apikey.FieldRateLimit1d, v)
return u
}
// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateRateLimit1d() *APIKeyUpsert {
u.SetExcluded(apikey.FieldRateLimit1d)
return u
}
// AddRateLimit1d adds v to the "rate_limit_1d" field.
func (u *APIKeyUpsert) AddRateLimit1d(v float64) *APIKeyUpsert {
u.Add(apikey.FieldRateLimit1d, v)
return u
}
// SetRateLimit7d sets the "rate_limit_7d" field.
func (u *APIKeyUpsert) SetRateLimit7d(v float64) *APIKeyUpsert {
u.Set(apikey.FieldRateLimit7d, v)
return u
}
// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateRateLimit7d() *APIKeyUpsert {
u.SetExcluded(apikey.FieldRateLimit7d)
return u
}
// AddRateLimit7d adds v to the "rate_limit_7d" field.
func (u *APIKeyUpsert) AddRateLimit7d(v float64) *APIKeyUpsert {
u.Add(apikey.FieldRateLimit7d, v)
return u
}
// SetUsage5h sets the "usage_5h" field.
func (u *APIKeyUpsert) SetUsage5h(v float64) *APIKeyUpsert {
u.Set(apikey.FieldUsage5h, v)
return u
}
// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateUsage5h() *APIKeyUpsert {
u.SetExcluded(apikey.FieldUsage5h)
return u
}
// AddUsage5h adds v to the "usage_5h" field.
func (u *APIKeyUpsert) AddUsage5h(v float64) *APIKeyUpsert {
u.Add(apikey.FieldUsage5h, v)
return u
}
// SetUsage1d sets the "usage_1d" field.
func (u *APIKeyUpsert) SetUsage1d(v float64) *APIKeyUpsert {
u.Set(apikey.FieldUsage1d, v)
return u
}
// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateUsage1d() *APIKeyUpsert {
u.SetExcluded(apikey.FieldUsage1d)
return u
}
// AddUsage1d adds v to the "usage_1d" field.
func (u *APIKeyUpsert) AddUsage1d(v float64) *APIKeyUpsert {
u.Add(apikey.FieldUsage1d, v)
return u
}
// SetUsage7d sets the "usage_7d" field.
func (u *APIKeyUpsert) SetUsage7d(v float64) *APIKeyUpsert {
u.Set(apikey.FieldUsage7d, v)
return u
}
// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateUsage7d() *APIKeyUpsert {
u.SetExcluded(apikey.FieldUsage7d)
return u
}
// AddUsage7d adds v to the "usage_7d" field.
func (u *APIKeyUpsert) AddUsage7d(v float64) *APIKeyUpsert {
u.Add(apikey.FieldUsage7d, v)
return u
}
// SetWindow5hStart sets the "window_5h_start" field.
func (u *APIKeyUpsert) SetWindow5hStart(v time.Time) *APIKeyUpsert {
u.Set(apikey.FieldWindow5hStart, v)
return u
}
// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateWindow5hStart() *APIKeyUpsert {
u.SetExcluded(apikey.FieldWindow5hStart)
return u
}
// ClearWindow5hStart clears the value of the "window_5h_start" field.
func (u *APIKeyUpsert) ClearWindow5hStart() *APIKeyUpsert {
u.SetNull(apikey.FieldWindow5hStart)
return u
}
// SetWindow1dStart sets the "window_1d_start" field.
func (u *APIKeyUpsert) SetWindow1dStart(v time.Time) *APIKeyUpsert {
u.Set(apikey.FieldWindow1dStart, v)
return u
}
// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateWindow1dStart() *APIKeyUpsert {
u.SetExcluded(apikey.FieldWindow1dStart)
return u
}
// ClearWindow1dStart clears the value of the "window_1d_start" field.
func (u *APIKeyUpsert) ClearWindow1dStart() *APIKeyUpsert {
u.SetNull(apikey.FieldWindow1dStart)
return u
}
// SetWindow7dStart sets the "window_7d_start" field.
func (u *APIKeyUpsert) SetWindow7dStart(v time.Time) *APIKeyUpsert {
u.Set(apikey.FieldWindow7dStart, v)
return u
}
// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateWindow7dStart() *APIKeyUpsert {
u.SetExcluded(apikey.FieldWindow7dStart)
return u
}
// ClearWindow7dStart clears the value of the "window_7d_start" field.
func (u *APIKeyUpsert) ClearWindow7dStart() *APIKeyUpsert {
u.SetNull(apikey.FieldWindow7dStart)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
@@ -980,6 +1346,195 @@ func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne {
}) })
} }
// SetRateLimit5h sets the "rate_limit_5h" field.
func (u *APIKeyUpsertOne) SetRateLimit5h(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetRateLimit5h(v)
})
}
// AddRateLimit5h adds v to the "rate_limit_5h" field.
func (u *APIKeyUpsertOne) AddRateLimit5h(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.AddRateLimit5h(v)
})
}
// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateRateLimit5h() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateRateLimit5h()
})
}
// SetRateLimit1d sets the "rate_limit_1d" field.
func (u *APIKeyUpsertOne) SetRateLimit1d(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetRateLimit1d(v)
})
}
// AddRateLimit1d adds v to the "rate_limit_1d" field.
func (u *APIKeyUpsertOne) AddRateLimit1d(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.AddRateLimit1d(v)
})
}
// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateRateLimit1d() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateRateLimit1d()
})
}
// SetRateLimit7d sets the "rate_limit_7d" field.
func (u *APIKeyUpsertOne) SetRateLimit7d(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetRateLimit7d(v)
})
}
// AddRateLimit7d adds v to the "rate_limit_7d" field.
func (u *APIKeyUpsertOne) AddRateLimit7d(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.AddRateLimit7d(v)
})
}
// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateRateLimit7d() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateRateLimit7d()
})
}
// SetUsage5h sets the "usage_5h" field.
func (u *APIKeyUpsertOne) SetUsage5h(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetUsage5h(v)
})
}
// AddUsage5h adds v to the "usage_5h" field.
func (u *APIKeyUpsertOne) AddUsage5h(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.AddUsage5h(v)
})
}
// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateUsage5h() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateUsage5h()
})
}
// SetUsage1d sets the "usage_1d" field.
func (u *APIKeyUpsertOne) SetUsage1d(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetUsage1d(v)
})
}
// AddUsage1d adds v to the "usage_1d" field.
func (u *APIKeyUpsertOne) AddUsage1d(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.AddUsage1d(v)
})
}
// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateUsage1d() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateUsage1d()
})
}
// SetUsage7d sets the "usage_7d" field.
func (u *APIKeyUpsertOne) SetUsage7d(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetUsage7d(v)
})
}
// AddUsage7d adds v to the "usage_7d" field.
func (u *APIKeyUpsertOne) AddUsage7d(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.AddUsage7d(v)
})
}
// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateUsage7d() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateUsage7d()
})
}
// SetWindow5hStart sets the "window_5h_start" field.
func (u *APIKeyUpsertOne) SetWindow5hStart(v time.Time) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetWindow5hStart(v)
})
}
// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateWindow5hStart() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateWindow5hStart()
})
}
// ClearWindow5hStart clears the value of the "window_5h_start" field.
func (u *APIKeyUpsertOne) ClearWindow5hStart() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearWindow5hStart()
})
}
// SetWindow1dStart sets the "window_1d_start" field.
func (u *APIKeyUpsertOne) SetWindow1dStart(v time.Time) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetWindow1dStart(v)
})
}
// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateWindow1dStart() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateWindow1dStart()
})
}
// ClearWindow1dStart clears the value of the "window_1d_start" field.
func (u *APIKeyUpsertOne) ClearWindow1dStart() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearWindow1dStart()
})
}
// SetWindow7dStart sets the "window_7d_start" field.
func (u *APIKeyUpsertOne) SetWindow7dStart(v time.Time) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetWindow7dStart(v)
})
}
// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateWindow7dStart() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateWindow7dStart()
})
}
// ClearWindow7dStart clears the value of the "window_7d_start" field.
func (u *APIKeyUpsertOne) ClearWindow7dStart() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearWindow7dStart()
})
}
// Exec executes the query. // Exec executes the query.
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
@@ -1429,6 +1984,195 @@ func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk {
}) })
} }
// SetRateLimit5h sets the "rate_limit_5h" field.
func (u *APIKeyUpsertBulk) SetRateLimit5h(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetRateLimit5h(v)
})
}
// AddRateLimit5h adds v to the "rate_limit_5h" field.
func (u *APIKeyUpsertBulk) AddRateLimit5h(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.AddRateLimit5h(v)
})
}
// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateRateLimit5h() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateRateLimit5h()
})
}
// SetRateLimit1d sets the "rate_limit_1d" field.
func (u *APIKeyUpsertBulk) SetRateLimit1d(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetRateLimit1d(v)
})
}
// AddRateLimit1d adds v to the "rate_limit_1d" field.
func (u *APIKeyUpsertBulk) AddRateLimit1d(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.AddRateLimit1d(v)
})
}
// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateRateLimit1d() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateRateLimit1d()
})
}
// SetRateLimit7d sets the "rate_limit_7d" field.
func (u *APIKeyUpsertBulk) SetRateLimit7d(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetRateLimit7d(v)
})
}
// AddRateLimit7d adds v to the "rate_limit_7d" field.
func (u *APIKeyUpsertBulk) AddRateLimit7d(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.AddRateLimit7d(v)
})
}
// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateRateLimit7d() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateRateLimit7d()
})
}
// SetUsage5h sets the "usage_5h" field.
func (u *APIKeyUpsertBulk) SetUsage5h(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetUsage5h(v)
})
}
// AddUsage5h adds v to the "usage_5h" field.
func (u *APIKeyUpsertBulk) AddUsage5h(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.AddUsage5h(v)
})
}
// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateUsage5h() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateUsage5h()
})
}
// SetUsage1d sets the "usage_1d" field.
func (u *APIKeyUpsertBulk) SetUsage1d(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetUsage1d(v)
})
}
// AddUsage1d adds v to the "usage_1d" field.
func (u *APIKeyUpsertBulk) AddUsage1d(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.AddUsage1d(v)
})
}
// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateUsage1d() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateUsage1d()
})
}
// SetUsage7d sets the "usage_7d" field.
func (u *APIKeyUpsertBulk) SetUsage7d(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetUsage7d(v)
})
}
// AddUsage7d adds v to the "usage_7d" field.
func (u *APIKeyUpsertBulk) AddUsage7d(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.AddUsage7d(v)
})
}
// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateUsage7d() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateUsage7d()
})
}
// SetWindow5hStart sets the "window_5h_start" field.
func (u *APIKeyUpsertBulk) SetWindow5hStart(v time.Time) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetWindow5hStart(v)
})
}
// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateWindow5hStart() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateWindow5hStart()
})
}
// ClearWindow5hStart clears the value of the "window_5h_start" field.
func (u *APIKeyUpsertBulk) ClearWindow5hStart() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearWindow5hStart()
})
}
// SetWindow1dStart sets the "window_1d_start" field.
func (u *APIKeyUpsertBulk) SetWindow1dStart(v time.Time) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetWindow1dStart(v)
})
}
// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateWindow1dStart() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateWindow1dStart()
})
}
// ClearWindow1dStart clears the value of the "window_1d_start" field.
func (u *APIKeyUpsertBulk) ClearWindow1dStart() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearWindow1dStart()
})
}
// SetWindow7dStart sets the "window_7d_start" field.
func (u *APIKeyUpsertBulk) SetWindow7dStart(v time.Time) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetWindow7dStart(v)
})
}
// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateWindow7dStart() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateWindow7dStart()
})
}
// ClearWindow7dStart clears the value of the "window_7d_start" field.
func (u *APIKeyUpsertBulk) ClearWindow7dStart() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearWindow7dStart()
})
}
// Exec executes the query. // Exec executes the query.
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {

View File

@@ -252,6 +252,192 @@ func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate {
return _u return _u
} }
// SetRateLimit5h sets the "rate_limit_5h" field.
func (_u *APIKeyUpdate) SetRateLimit5h(v float64) *APIKeyUpdate {
_u.mutation.ResetRateLimit5h()
_u.mutation.SetRateLimit5h(v)
return _u
}
// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableRateLimit5h(v *float64) *APIKeyUpdate {
if v != nil {
_u.SetRateLimit5h(*v)
}
return _u
}
// AddRateLimit5h adds value to the "rate_limit_5h" field.
func (_u *APIKeyUpdate) AddRateLimit5h(v float64) *APIKeyUpdate {
_u.mutation.AddRateLimit5h(v)
return _u
}
// SetRateLimit1d sets the "rate_limit_1d" field.
func (_u *APIKeyUpdate) SetRateLimit1d(v float64) *APIKeyUpdate {
_u.mutation.ResetRateLimit1d()
_u.mutation.SetRateLimit1d(v)
return _u
}
// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableRateLimit1d(v *float64) *APIKeyUpdate {
if v != nil {
_u.SetRateLimit1d(*v)
}
return _u
}
// AddRateLimit1d adds value to the "rate_limit_1d" field.
func (_u *APIKeyUpdate) AddRateLimit1d(v float64) *APIKeyUpdate {
_u.mutation.AddRateLimit1d(v)
return _u
}
// SetRateLimit7d sets the "rate_limit_7d" field.
func (_u *APIKeyUpdate) SetRateLimit7d(v float64) *APIKeyUpdate {
_u.mutation.ResetRateLimit7d()
_u.mutation.SetRateLimit7d(v)
return _u
}
// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableRateLimit7d(v *float64) *APIKeyUpdate {
if v != nil {
_u.SetRateLimit7d(*v)
}
return _u
}
// AddRateLimit7d adds value to the "rate_limit_7d" field.
func (_u *APIKeyUpdate) AddRateLimit7d(v float64) *APIKeyUpdate {
_u.mutation.AddRateLimit7d(v)
return _u
}
// SetUsage5h sets the "usage_5h" field.
func (_u *APIKeyUpdate) SetUsage5h(v float64) *APIKeyUpdate {
_u.mutation.ResetUsage5h()
_u.mutation.SetUsage5h(v)
return _u
}
// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableUsage5h(v *float64) *APIKeyUpdate {
if v != nil {
_u.SetUsage5h(*v)
}
return _u
}
// AddUsage5h adds value to the "usage_5h" field.
func (_u *APIKeyUpdate) AddUsage5h(v float64) *APIKeyUpdate {
_u.mutation.AddUsage5h(v)
return _u
}
// SetUsage1d sets the "usage_1d" field.
func (_u *APIKeyUpdate) SetUsage1d(v float64) *APIKeyUpdate {
_u.mutation.ResetUsage1d()
_u.mutation.SetUsage1d(v)
return _u
}
// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableUsage1d(v *float64) *APIKeyUpdate {
if v != nil {
_u.SetUsage1d(*v)
}
return _u
}
// AddUsage1d adds value to the "usage_1d" field.
func (_u *APIKeyUpdate) AddUsage1d(v float64) *APIKeyUpdate {
_u.mutation.AddUsage1d(v)
return _u
}
// SetUsage7d sets the "usage_7d" field.
func (_u *APIKeyUpdate) SetUsage7d(v float64) *APIKeyUpdate {
_u.mutation.ResetUsage7d()
_u.mutation.SetUsage7d(v)
return _u
}
// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableUsage7d(v *float64) *APIKeyUpdate {
if v != nil {
_u.SetUsage7d(*v)
}
return _u
}
// AddUsage7d adds value to the "usage_7d" field.
func (_u *APIKeyUpdate) AddUsage7d(v float64) *APIKeyUpdate {
_u.mutation.AddUsage7d(v)
return _u
}
// SetWindow5hStart sets the "window_5h_start" field.
func (_u *APIKeyUpdate) SetWindow5hStart(v time.Time) *APIKeyUpdate {
_u.mutation.SetWindow5hStart(v)
return _u
}
// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdate {
if v != nil {
_u.SetWindow5hStart(*v)
}
return _u
}
// ClearWindow5hStart clears the value of the "window_5h_start" field.
func (_u *APIKeyUpdate) ClearWindow5hStart() *APIKeyUpdate {
_u.mutation.ClearWindow5hStart()
return _u
}
// SetWindow1dStart sets the "window_1d_start" field.
func (_u *APIKeyUpdate) SetWindow1dStart(v time.Time) *APIKeyUpdate {
_u.mutation.SetWindow1dStart(v)
return _u
}
// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdate {
if v != nil {
_u.SetWindow1dStart(*v)
}
return _u
}
// ClearWindow1dStart clears the value of the "window_1d_start" field.
func (_u *APIKeyUpdate) ClearWindow1dStart() *APIKeyUpdate {
_u.mutation.ClearWindow1dStart()
return _u
}
// SetWindow7dStart sets the "window_7d_start" field.
func (_u *APIKeyUpdate) SetWindow7dStart(v time.Time) *APIKeyUpdate {
_u.mutation.SetWindow7dStart(v)
return _u
}
// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdate {
if v != nil {
_u.SetWindow7dStart(*v)
}
return _u
}
// ClearWindow7dStart clears the value of the "window_7d_start" field.
func (_u *APIKeyUpdate) ClearWindow7dStart() *APIKeyUpdate {
_u.mutation.ClearWindow7dStart()
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
@@ -456,6 +642,60 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ExpiresAtCleared() { if _u.mutation.ExpiresAtCleared() {
_spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime)
} }
if value, ok := _u.mutation.RateLimit5h(); ok {
_spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateLimit5h(); ok {
_spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
}
if value, ok := _u.mutation.RateLimit1d(); ok {
_spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateLimit1d(); ok {
_spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.RateLimit7d(); ok {
_spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateLimit7d(); ok {
_spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Usage5h(); ok {
_spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedUsage5h(); ok {
_spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Usage1d(); ok {
_spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedUsage1d(); ok {
_spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Usage7d(); ok {
_spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedUsage7d(); ok {
_spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Window5hStart(); ok {
_spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value)
}
if _u.mutation.Window5hStartCleared() {
_spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime)
}
if value, ok := _u.mutation.Window1dStart(); ok {
_spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value)
}
if _u.mutation.Window1dStartCleared() {
_spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime)
}
if value, ok := _u.mutation.Window7dStart(); ok {
_spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value)
}
if _u.mutation.Window7dStartCleared() {
_spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
@@ -799,6 +1039,192 @@ func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne {
return _u return _u
} }
// SetRateLimit5h sets the "rate_limit_5h" field.
func (_u *APIKeyUpdateOne) SetRateLimit5h(v float64) *APIKeyUpdateOne {
_u.mutation.ResetRateLimit5h()
_u.mutation.SetRateLimit5h(v)
return _u
}
// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableRateLimit5h(v *float64) *APIKeyUpdateOne {
if v != nil {
_u.SetRateLimit5h(*v)
}
return _u
}
// AddRateLimit5h adds value to the "rate_limit_5h" field.
func (_u *APIKeyUpdateOne) AddRateLimit5h(v float64) *APIKeyUpdateOne {
_u.mutation.AddRateLimit5h(v)
return _u
}
// SetRateLimit1d sets the "rate_limit_1d" field.
func (_u *APIKeyUpdateOne) SetRateLimit1d(v float64) *APIKeyUpdateOne {
_u.mutation.ResetRateLimit1d()
_u.mutation.SetRateLimit1d(v)
return _u
}
// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableRateLimit1d(v *float64) *APIKeyUpdateOne {
if v != nil {
_u.SetRateLimit1d(*v)
}
return _u
}
// AddRateLimit1d adds value to the "rate_limit_1d" field.
func (_u *APIKeyUpdateOne) AddRateLimit1d(v float64) *APIKeyUpdateOne {
_u.mutation.AddRateLimit1d(v)
return _u
}
// SetRateLimit7d sets the "rate_limit_7d" field.
func (_u *APIKeyUpdateOne) SetRateLimit7d(v float64) *APIKeyUpdateOne {
_u.mutation.ResetRateLimit7d()
_u.mutation.SetRateLimit7d(v)
return _u
}
// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableRateLimit7d(v *float64) *APIKeyUpdateOne {
if v != nil {
_u.SetRateLimit7d(*v)
}
return _u
}
// AddRateLimit7d adds value to the "rate_limit_7d" field.
func (_u *APIKeyUpdateOne) AddRateLimit7d(v float64) *APIKeyUpdateOne {
_u.mutation.AddRateLimit7d(v)
return _u
}
// SetUsage5h sets the "usage_5h" field.
func (_u *APIKeyUpdateOne) SetUsage5h(v float64) *APIKeyUpdateOne {
_u.mutation.ResetUsage5h()
_u.mutation.SetUsage5h(v)
return _u
}
// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableUsage5h(v *float64) *APIKeyUpdateOne {
if v != nil {
_u.SetUsage5h(*v)
}
return _u
}
// AddUsage5h adds value to the "usage_5h" field.
func (_u *APIKeyUpdateOne) AddUsage5h(v float64) *APIKeyUpdateOne {
_u.mutation.AddUsage5h(v)
return _u
}
// SetUsage1d sets the "usage_1d" field.
func (_u *APIKeyUpdateOne) SetUsage1d(v float64) *APIKeyUpdateOne {
_u.mutation.ResetUsage1d()
_u.mutation.SetUsage1d(v)
return _u
}
// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableUsage1d(v *float64) *APIKeyUpdateOne {
if v != nil {
_u.SetUsage1d(*v)
}
return _u
}
// AddUsage1d adds value to the "usage_1d" field.
func (_u *APIKeyUpdateOne) AddUsage1d(v float64) *APIKeyUpdateOne {
_u.mutation.AddUsage1d(v)
return _u
}
// SetUsage7d sets the "usage_7d" field.
func (_u *APIKeyUpdateOne) SetUsage7d(v float64) *APIKeyUpdateOne {
_u.mutation.ResetUsage7d()
_u.mutation.SetUsage7d(v)
return _u
}
// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableUsage7d(v *float64) *APIKeyUpdateOne {
if v != nil {
_u.SetUsage7d(*v)
}
return _u
}
// AddUsage7d adds value to the "usage_7d" field.
func (_u *APIKeyUpdateOne) AddUsage7d(v float64) *APIKeyUpdateOne {
_u.mutation.AddUsage7d(v)
return _u
}
// SetWindow5hStart sets the "window_5h_start" field.
func (_u *APIKeyUpdateOne) SetWindow5hStart(v time.Time) *APIKeyUpdateOne {
_u.mutation.SetWindow5hStart(v)
return _u
}
// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdateOne {
if v != nil {
_u.SetWindow5hStart(*v)
}
return _u
}
// ClearWindow5hStart clears the value of the "window_5h_start" field.
func (_u *APIKeyUpdateOne) ClearWindow5hStart() *APIKeyUpdateOne {
_u.mutation.ClearWindow5hStart()
return _u
}
// SetWindow1dStart sets the "window_1d_start" field.
func (_u *APIKeyUpdateOne) SetWindow1dStart(v time.Time) *APIKeyUpdateOne {
_u.mutation.SetWindow1dStart(v)
return _u
}
// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdateOne {
if v != nil {
_u.SetWindow1dStart(*v)
}
return _u
}
// ClearWindow1dStart clears the value of the "window_1d_start" field.
func (_u *APIKeyUpdateOne) ClearWindow1dStart() *APIKeyUpdateOne {
_u.mutation.ClearWindow1dStart()
return _u
}
// SetWindow7dStart sets the "window_7d_start" field.
func (_u *APIKeyUpdateOne) SetWindow7dStart(v time.Time) *APIKeyUpdateOne {
_u.mutation.SetWindow7dStart(v)
return _u
}
// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdateOne {
if v != nil {
_u.SetWindow7dStart(*v)
}
return _u
}
// ClearWindow7dStart clears the value of the "window_7d_start" field.
func (_u *APIKeyUpdateOne) ClearWindow7dStart() *APIKeyUpdateOne {
_u.mutation.ClearWindow7dStart()
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
@@ -1033,6 +1459,60 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
if _u.mutation.ExpiresAtCleared() { if _u.mutation.ExpiresAtCleared() {
_spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime)
} }
if value, ok := _u.mutation.RateLimit5h(); ok {
_spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateLimit5h(); ok {
_spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
}
if value, ok := _u.mutation.RateLimit1d(); ok {
_spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateLimit1d(); ok {
_spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.RateLimit7d(); ok {
_spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateLimit7d(); ok {
_spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Usage5h(); ok {
_spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedUsage5h(); ok {
_spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Usage1d(); ok {
_spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedUsage1d(); ok {
_spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Usage7d(); ok {
_spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedUsage7d(); ok {
_spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Window5hStart(); ok {
_spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value)
}
if _u.mutation.Window5hStartCleared() {
_spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime)
}
if value, ok := _u.mutation.Window1dStart(); ok {
_spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value)
}
if _u.mutation.Window1dStartCleared() {
_spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime)
}
if value, ok := _u.mutation.Window7dStart(); ok {
_spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value)
}
if _u.mutation.Window7dStartCleared() {
_spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,

View File

@@ -24,6 +24,15 @@ var (
{Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "expires_at", Type: field.TypeTime, Nullable: true}, {Name: "expires_at", Type: field.TypeTime, Nullable: true},
{Name: "rate_limit_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "rate_limit_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "rate_limit_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "usage_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "usage_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "usage_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "window_5h_start", Type: field.TypeTime, Nullable: true},
{Name: "window_1d_start", Type: field.TypeTime, Nullable: true},
{Name: "window_7d_start", Type: field.TypeTime, Nullable: true},
{Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true},
{Name: "user_id", Type: field.TypeInt64}, {Name: "user_id", Type: field.TypeInt64},
} }
@@ -35,13 +44,13 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "api_keys_groups_api_keys", Symbol: "api_keys_groups_api_keys",
Columns: []*schema.Column{APIKeysColumns[13]}, Columns: []*schema.Column{APIKeysColumns[22]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "api_keys_users_api_keys", Symbol: "api_keys_users_api_keys",
Columns: []*schema.Column{APIKeysColumns[14]}, Columns: []*schema.Column{APIKeysColumns[23]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
@@ -50,12 +59,12 @@ var (
{ {
Name: "apikey_user_id", Name: "apikey_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{APIKeysColumns[14]}, Columns: []*schema.Column{APIKeysColumns[23]},
}, },
{ {
Name: "apikey_group_id", Name: "apikey_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{APIKeysColumns[13]}, Columns: []*schema.Column{APIKeysColumns[22]},
}, },
{ {
Name: "apikey_status", Name: "apikey_status",

View File

@@ -91,6 +91,21 @@ type APIKeyMutation struct {
quota_used *float64 quota_used *float64
addquota_used *float64 addquota_used *float64
expires_at *time.Time expires_at *time.Time
rate_limit_5h *float64
addrate_limit_5h *float64
rate_limit_1d *float64
addrate_limit_1d *float64
rate_limit_7d *float64
addrate_limit_7d *float64
usage_5h *float64
addusage_5h *float64
usage_1d *float64
addusage_1d *float64
usage_7d *float64
addusage_7d *float64
window_5h_start *time.Time
window_1d_start *time.Time
window_7d_start *time.Time
clearedFields map[string]struct{} clearedFields map[string]struct{}
user *int64 user *int64
cleareduser bool cleareduser bool
@@ -856,6 +871,489 @@ func (m *APIKeyMutation) ResetExpiresAt() {
delete(m.clearedFields, apikey.FieldExpiresAt) delete(m.clearedFields, apikey.FieldExpiresAt)
} }
// SetRateLimit5h sets the "rate_limit_5h" field.
func (m *APIKeyMutation) SetRateLimit5h(f float64) {
m.rate_limit_5h = &f
m.addrate_limit_5h = nil
}
// RateLimit5h returns the value of the "rate_limit_5h" field in the mutation.
func (m *APIKeyMutation) RateLimit5h() (r float64, exists bool) {
v := m.rate_limit_5h
if v == nil {
return
}
return *v, true
}
// OldRateLimit5h returns the old "rate_limit_5h" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldRateLimit5h(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldRateLimit5h is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldRateLimit5h requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldRateLimit5h: %w", err)
}
return oldValue.RateLimit5h, nil
}
// AddRateLimit5h adds f to the "rate_limit_5h" field.
func (m *APIKeyMutation) AddRateLimit5h(f float64) {
if m.addrate_limit_5h != nil {
*m.addrate_limit_5h += f
} else {
m.addrate_limit_5h = &f
}
}
// AddedRateLimit5h returns the value that was added to the "rate_limit_5h" field in this mutation.
func (m *APIKeyMutation) AddedRateLimit5h() (r float64, exists bool) {
v := m.addrate_limit_5h
if v == nil {
return
}
return *v, true
}
// ResetRateLimit5h resets all changes to the "rate_limit_5h" field.
func (m *APIKeyMutation) ResetRateLimit5h() {
m.rate_limit_5h = nil
m.addrate_limit_5h = nil
}
// SetRateLimit1d sets the "rate_limit_1d" field.
func (m *APIKeyMutation) SetRateLimit1d(f float64) {
m.rate_limit_1d = &f
m.addrate_limit_1d = nil
}
// RateLimit1d returns the value of the "rate_limit_1d" field in the mutation.
func (m *APIKeyMutation) RateLimit1d() (r float64, exists bool) {
v := m.rate_limit_1d
if v == nil {
return
}
return *v, true
}
// OldRateLimit1d returns the old "rate_limit_1d" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldRateLimit1d(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldRateLimit1d is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldRateLimit1d requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldRateLimit1d: %w", err)
}
return oldValue.RateLimit1d, nil
}
// AddRateLimit1d adds f to the "rate_limit_1d" field.
func (m *APIKeyMutation) AddRateLimit1d(f float64) {
if m.addrate_limit_1d != nil {
*m.addrate_limit_1d += f
} else {
m.addrate_limit_1d = &f
}
}
// AddedRateLimit1d returns the value that was added to the "rate_limit_1d" field in this mutation.
func (m *APIKeyMutation) AddedRateLimit1d() (r float64, exists bool) {
v := m.addrate_limit_1d
if v == nil {
return
}
return *v, true
}
// ResetRateLimit1d resets all changes to the "rate_limit_1d" field.
func (m *APIKeyMutation) ResetRateLimit1d() {
m.rate_limit_1d = nil
m.addrate_limit_1d = nil
}
// SetRateLimit7d sets the "rate_limit_7d" field.
func (m *APIKeyMutation) SetRateLimit7d(f float64) {
m.rate_limit_7d = &f
m.addrate_limit_7d = nil
}
// RateLimit7d returns the value of the "rate_limit_7d" field in the mutation.
func (m *APIKeyMutation) RateLimit7d() (r float64, exists bool) {
v := m.rate_limit_7d
if v == nil {
return
}
return *v, true
}
// OldRateLimit7d returns the old "rate_limit_7d" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldRateLimit7d(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldRateLimit7d is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldRateLimit7d requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldRateLimit7d: %w", err)
}
return oldValue.RateLimit7d, nil
}
// AddRateLimit7d adds f to the "rate_limit_7d" field.
func (m *APIKeyMutation) AddRateLimit7d(f float64) {
if m.addrate_limit_7d != nil {
*m.addrate_limit_7d += f
} else {
m.addrate_limit_7d = &f
}
}
// AddedRateLimit7d returns the value that was added to the "rate_limit_7d" field in this mutation.
func (m *APIKeyMutation) AddedRateLimit7d() (r float64, exists bool) {
v := m.addrate_limit_7d
if v == nil {
return
}
return *v, true
}
// ResetRateLimit7d resets all changes to the "rate_limit_7d" field.
func (m *APIKeyMutation) ResetRateLimit7d() {
m.rate_limit_7d = nil
m.addrate_limit_7d = nil
}
// SetUsage5h sets the "usage_5h" field.
func (m *APIKeyMutation) SetUsage5h(f float64) {
m.usage_5h = &f
m.addusage_5h = nil
}
// Usage5h returns the value of the "usage_5h" field in the mutation.
func (m *APIKeyMutation) Usage5h() (r float64, exists bool) {
v := m.usage_5h
if v == nil {
return
}
return *v, true
}
// OldUsage5h returns the old "usage_5h" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldUsage5h(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUsage5h is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUsage5h requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUsage5h: %w", err)
}
return oldValue.Usage5h, nil
}
// AddUsage5h adds f to the "usage_5h" field.
func (m *APIKeyMutation) AddUsage5h(f float64) {
if m.addusage_5h != nil {
*m.addusage_5h += f
} else {
m.addusage_5h = &f
}
}
// AddedUsage5h returns the value that was added to the "usage_5h" field in this mutation.
func (m *APIKeyMutation) AddedUsage5h() (r float64, exists bool) {
v := m.addusage_5h
if v == nil {
return
}
return *v, true
}
// ResetUsage5h resets all changes to the "usage_5h" field.
func (m *APIKeyMutation) ResetUsage5h() {
m.usage_5h = nil
m.addusage_5h = nil
}
// SetUsage1d sets the "usage_1d" field.
func (m *APIKeyMutation) SetUsage1d(f float64) {
m.usage_1d = &f
m.addusage_1d = nil
}
// Usage1d returns the value of the "usage_1d" field in the mutation.
func (m *APIKeyMutation) Usage1d() (r float64, exists bool) {
v := m.usage_1d
if v == nil {
return
}
return *v, true
}
// OldUsage1d returns the old "usage_1d" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldUsage1d(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUsage1d is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUsage1d requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUsage1d: %w", err)
}
return oldValue.Usage1d, nil
}
// AddUsage1d adds f to the "usage_1d" field.
func (m *APIKeyMutation) AddUsage1d(f float64) {
if m.addusage_1d != nil {
*m.addusage_1d += f
} else {
m.addusage_1d = &f
}
}
// AddedUsage1d returns the value that was added to the "usage_1d" field in this mutation.
func (m *APIKeyMutation) AddedUsage1d() (r float64, exists bool) {
v := m.addusage_1d
if v == nil {
return
}
return *v, true
}
// ResetUsage1d resets all changes to the "usage_1d" field.
func (m *APIKeyMutation) ResetUsage1d() {
m.usage_1d = nil
m.addusage_1d = nil
}
// SetUsage7d sets the "usage_7d" field.
func (m *APIKeyMutation) SetUsage7d(f float64) {
m.usage_7d = &f
m.addusage_7d = nil
}
// Usage7d returns the value of the "usage_7d" field in the mutation.
func (m *APIKeyMutation) Usage7d() (r float64, exists bool) {
v := m.usage_7d
if v == nil {
return
}
return *v, true
}
// OldUsage7d returns the old "usage_7d" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldUsage7d(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUsage7d is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUsage7d requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUsage7d: %w", err)
}
return oldValue.Usage7d, nil
}
// AddUsage7d adds f to the "usage_7d" field.
func (m *APIKeyMutation) AddUsage7d(f float64) {
if m.addusage_7d != nil {
*m.addusage_7d += f
} else {
m.addusage_7d = &f
}
}
// AddedUsage7d returns the value that was added to the "usage_7d" field in this mutation.
func (m *APIKeyMutation) AddedUsage7d() (r float64, exists bool) {
v := m.addusage_7d
if v == nil {
return
}
return *v, true
}
// ResetUsage7d resets all changes to the "usage_7d" field.
func (m *APIKeyMutation) ResetUsage7d() {
m.usage_7d = nil
m.addusage_7d = nil
}
// SetWindow5hStart sets the "window_5h_start" field.
func (m *APIKeyMutation) SetWindow5hStart(t time.Time) {
m.window_5h_start = &t
}
// Window5hStart returns the value of the "window_5h_start" field in the mutation.
func (m *APIKeyMutation) Window5hStart() (r time.Time, exists bool) {
v := m.window_5h_start
if v == nil {
return
}
return *v, true
}
// OldWindow5hStart returns the old "window_5h_start" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldWindow5hStart(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldWindow5hStart is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldWindow5hStart requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldWindow5hStart: %w", err)
}
return oldValue.Window5hStart, nil
}
// ClearWindow5hStart clears the value of the "window_5h_start" field.
func (m *APIKeyMutation) ClearWindow5hStart() {
m.window_5h_start = nil
m.clearedFields[apikey.FieldWindow5hStart] = struct{}{}
}
// Window5hStartCleared returns if the "window_5h_start" field was cleared in this mutation.
func (m *APIKeyMutation) Window5hStartCleared() bool {
_, ok := m.clearedFields[apikey.FieldWindow5hStart]
return ok
}
// ResetWindow5hStart resets all changes to the "window_5h_start" field.
func (m *APIKeyMutation) ResetWindow5hStart() {
m.window_5h_start = nil
delete(m.clearedFields, apikey.FieldWindow5hStart)
}
// SetWindow1dStart sets the "window_1d_start" field.
func (m *APIKeyMutation) SetWindow1dStart(t time.Time) {
m.window_1d_start = &t
}
// Window1dStart returns the value of the "window_1d_start" field in the mutation.
func (m *APIKeyMutation) Window1dStart() (r time.Time, exists bool) {
v := m.window_1d_start
if v == nil {
return
}
return *v, true
}
// OldWindow1dStart returns the old "window_1d_start" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldWindow1dStart(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldWindow1dStart is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldWindow1dStart requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldWindow1dStart: %w", err)
}
return oldValue.Window1dStart, nil
}
// ClearWindow1dStart clears the value of the "window_1d_start" field.
func (m *APIKeyMutation) ClearWindow1dStart() {
m.window_1d_start = nil
m.clearedFields[apikey.FieldWindow1dStart] = struct{}{}
}
// Window1dStartCleared returns if the "window_1d_start" field was cleared in this mutation.
func (m *APIKeyMutation) Window1dStartCleared() bool {
_, ok := m.clearedFields[apikey.FieldWindow1dStart]
return ok
}
// ResetWindow1dStart resets all changes to the "window_1d_start" field.
func (m *APIKeyMutation) ResetWindow1dStart() {
m.window_1d_start = nil
delete(m.clearedFields, apikey.FieldWindow1dStart)
}
// SetWindow7dStart sets the "window_7d_start" field.
func (m *APIKeyMutation) SetWindow7dStart(t time.Time) {
m.window_7d_start = &t
}
// Window7dStart returns the value of the "window_7d_start" field in the mutation.
func (m *APIKeyMutation) Window7dStart() (r time.Time, exists bool) {
v := m.window_7d_start
if v == nil {
return
}
return *v, true
}
// OldWindow7dStart returns the old "window_7d_start" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldWindow7dStart(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldWindow7dStart is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldWindow7dStart requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldWindow7dStart: %w", err)
}
return oldValue.Window7dStart, nil
}
// ClearWindow7dStart clears the value of the "window_7d_start" field.
func (m *APIKeyMutation) ClearWindow7dStart() {
m.window_7d_start = nil
m.clearedFields[apikey.FieldWindow7dStart] = struct{}{}
}
// Window7dStartCleared returns if the "window_7d_start" field was cleared in this mutation.
func (m *APIKeyMutation) Window7dStartCleared() bool {
_, ok := m.clearedFields[apikey.FieldWindow7dStart]
return ok
}
// ResetWindow7dStart resets all changes to the "window_7d_start" field.
func (m *APIKeyMutation) ResetWindow7dStart() {
m.window_7d_start = nil
delete(m.clearedFields, apikey.FieldWindow7dStart)
}
// ClearUser clears the "user" edge to the User entity. // ClearUser clears the "user" edge to the User entity.
func (m *APIKeyMutation) ClearUser() { func (m *APIKeyMutation) ClearUser() {
m.cleareduser = true m.cleareduser = true
@@ -998,7 +1496,7 @@ func (m *APIKeyMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *APIKeyMutation) Fields() []string { func (m *APIKeyMutation) Fields() []string {
fields := make([]string, 0, 14) fields := make([]string, 0, 23)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, apikey.FieldCreatedAt) fields = append(fields, apikey.FieldCreatedAt)
} }
@@ -1041,6 +1539,33 @@ func (m *APIKeyMutation) Fields() []string {
if m.expires_at != nil { if m.expires_at != nil {
fields = append(fields, apikey.FieldExpiresAt) fields = append(fields, apikey.FieldExpiresAt)
} }
if m.rate_limit_5h != nil {
fields = append(fields, apikey.FieldRateLimit5h)
}
if m.rate_limit_1d != nil {
fields = append(fields, apikey.FieldRateLimit1d)
}
if m.rate_limit_7d != nil {
fields = append(fields, apikey.FieldRateLimit7d)
}
if m.usage_5h != nil {
fields = append(fields, apikey.FieldUsage5h)
}
if m.usage_1d != nil {
fields = append(fields, apikey.FieldUsage1d)
}
if m.usage_7d != nil {
fields = append(fields, apikey.FieldUsage7d)
}
if m.window_5h_start != nil {
fields = append(fields, apikey.FieldWindow5hStart)
}
if m.window_1d_start != nil {
fields = append(fields, apikey.FieldWindow1dStart)
}
if m.window_7d_start != nil {
fields = append(fields, apikey.FieldWindow7dStart)
}
return fields return fields
} }
@@ -1077,6 +1602,24 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) {
return m.QuotaUsed() return m.QuotaUsed()
case apikey.FieldExpiresAt: case apikey.FieldExpiresAt:
return m.ExpiresAt() return m.ExpiresAt()
case apikey.FieldRateLimit5h:
return m.RateLimit5h()
case apikey.FieldRateLimit1d:
return m.RateLimit1d()
case apikey.FieldRateLimit7d:
return m.RateLimit7d()
case apikey.FieldUsage5h:
return m.Usage5h()
case apikey.FieldUsage1d:
return m.Usage1d()
case apikey.FieldUsage7d:
return m.Usage7d()
case apikey.FieldWindow5hStart:
return m.Window5hStart()
case apikey.FieldWindow1dStart:
return m.Window1dStart()
case apikey.FieldWindow7dStart:
return m.Window7dStart()
} }
return nil, false return nil, false
} }
@@ -1114,6 +1657,24 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldQuotaUsed(ctx) return m.OldQuotaUsed(ctx)
case apikey.FieldExpiresAt: case apikey.FieldExpiresAt:
return m.OldExpiresAt(ctx) return m.OldExpiresAt(ctx)
case apikey.FieldRateLimit5h:
return m.OldRateLimit5h(ctx)
case apikey.FieldRateLimit1d:
return m.OldRateLimit1d(ctx)
case apikey.FieldRateLimit7d:
return m.OldRateLimit7d(ctx)
case apikey.FieldUsage5h:
return m.OldUsage5h(ctx)
case apikey.FieldUsage1d:
return m.OldUsage1d(ctx)
case apikey.FieldUsage7d:
return m.OldUsage7d(ctx)
case apikey.FieldWindow5hStart:
return m.OldWindow5hStart(ctx)
case apikey.FieldWindow1dStart:
return m.OldWindow1dStart(ctx)
case apikey.FieldWindow7dStart:
return m.OldWindow7dStart(ctx)
} }
return nil, fmt.Errorf("unknown APIKey field %s", name) return nil, fmt.Errorf("unknown APIKey field %s", name)
} }
@@ -1221,6 +1782,69 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error {
} }
m.SetExpiresAt(v) m.SetExpiresAt(v)
return nil return nil
case apikey.FieldRateLimit5h:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetRateLimit5h(v)
return nil
case apikey.FieldRateLimit1d:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetRateLimit1d(v)
return nil
case apikey.FieldRateLimit7d:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetRateLimit7d(v)
return nil
case apikey.FieldUsage5h:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUsage5h(v)
return nil
case apikey.FieldUsage1d:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUsage1d(v)
return nil
case apikey.FieldUsage7d:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUsage7d(v)
return nil
case apikey.FieldWindow5hStart:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetWindow5hStart(v)
return nil
case apikey.FieldWindow1dStart:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetWindow1dStart(v)
return nil
case apikey.FieldWindow7dStart:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetWindow7dStart(v)
return nil
} }
return fmt.Errorf("unknown APIKey field %s", name) return fmt.Errorf("unknown APIKey field %s", name)
} }
@@ -1235,6 +1859,24 @@ func (m *APIKeyMutation) AddedFields() []string {
if m.addquota_used != nil { if m.addquota_used != nil {
fields = append(fields, apikey.FieldQuotaUsed) fields = append(fields, apikey.FieldQuotaUsed)
} }
if m.addrate_limit_5h != nil {
fields = append(fields, apikey.FieldRateLimit5h)
}
if m.addrate_limit_1d != nil {
fields = append(fields, apikey.FieldRateLimit1d)
}
if m.addrate_limit_7d != nil {
fields = append(fields, apikey.FieldRateLimit7d)
}
if m.addusage_5h != nil {
fields = append(fields, apikey.FieldUsage5h)
}
if m.addusage_1d != nil {
fields = append(fields, apikey.FieldUsage1d)
}
if m.addusage_7d != nil {
fields = append(fields, apikey.FieldUsage7d)
}
return fields return fields
} }
@@ -1247,6 +1889,18 @@ func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedQuota() return m.AddedQuota()
case apikey.FieldQuotaUsed: case apikey.FieldQuotaUsed:
return m.AddedQuotaUsed() return m.AddedQuotaUsed()
case apikey.FieldRateLimit5h:
return m.AddedRateLimit5h()
case apikey.FieldRateLimit1d:
return m.AddedRateLimit1d()
case apikey.FieldRateLimit7d:
return m.AddedRateLimit7d()
case apikey.FieldUsage5h:
return m.AddedUsage5h()
case apikey.FieldUsage1d:
return m.AddedUsage1d()
case apikey.FieldUsage7d:
return m.AddedUsage7d()
} }
return nil, false return nil, false
} }
@@ -1270,6 +1924,48 @@ func (m *APIKeyMutation) AddField(name string, value ent.Value) error {
} }
m.AddQuotaUsed(v) m.AddQuotaUsed(v)
return nil return nil
case apikey.FieldRateLimit5h:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddRateLimit5h(v)
return nil
case apikey.FieldRateLimit1d:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddRateLimit1d(v)
return nil
case apikey.FieldRateLimit7d:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddRateLimit7d(v)
return nil
case apikey.FieldUsage5h:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddUsage5h(v)
return nil
case apikey.FieldUsage1d:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddUsage1d(v)
return nil
case apikey.FieldUsage7d:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddUsage7d(v)
return nil
} }
return fmt.Errorf("unknown APIKey numeric field %s", name) return fmt.Errorf("unknown APIKey numeric field %s", name)
} }
@@ -1296,6 +1992,15 @@ func (m *APIKeyMutation) ClearedFields() []string {
if m.FieldCleared(apikey.FieldExpiresAt) { if m.FieldCleared(apikey.FieldExpiresAt) {
fields = append(fields, apikey.FieldExpiresAt) fields = append(fields, apikey.FieldExpiresAt)
} }
if m.FieldCleared(apikey.FieldWindow5hStart) {
fields = append(fields, apikey.FieldWindow5hStart)
}
if m.FieldCleared(apikey.FieldWindow1dStart) {
fields = append(fields, apikey.FieldWindow1dStart)
}
if m.FieldCleared(apikey.FieldWindow7dStart) {
fields = append(fields, apikey.FieldWindow7dStart)
}
return fields return fields
} }
@@ -1328,6 +2033,15 @@ func (m *APIKeyMutation) ClearField(name string) error {
case apikey.FieldExpiresAt: case apikey.FieldExpiresAt:
m.ClearExpiresAt() m.ClearExpiresAt()
return nil return nil
case apikey.FieldWindow5hStart:
m.ClearWindow5hStart()
return nil
case apikey.FieldWindow1dStart:
m.ClearWindow1dStart()
return nil
case apikey.FieldWindow7dStart:
m.ClearWindow7dStart()
return nil
} }
return fmt.Errorf("unknown APIKey nullable field %s", name) return fmt.Errorf("unknown APIKey nullable field %s", name)
} }
@@ -1378,6 +2092,33 @@ func (m *APIKeyMutation) ResetField(name string) error {
case apikey.FieldExpiresAt: case apikey.FieldExpiresAt:
m.ResetExpiresAt() m.ResetExpiresAt()
return nil return nil
case apikey.FieldRateLimit5h:
m.ResetRateLimit5h()
return nil
case apikey.FieldRateLimit1d:
m.ResetRateLimit1d()
return nil
case apikey.FieldRateLimit7d:
m.ResetRateLimit7d()
return nil
case apikey.FieldUsage5h:
m.ResetUsage5h()
return nil
case apikey.FieldUsage1d:
m.ResetUsage1d()
return nil
case apikey.FieldUsage7d:
m.ResetUsage7d()
return nil
case apikey.FieldWindow5hStart:
m.ResetWindow5hStart()
return nil
case apikey.FieldWindow1dStart:
m.ResetWindow1dStart()
return nil
case apikey.FieldWindow7dStart:
m.ResetWindow7dStart()
return nil
} }
return fmt.Errorf("unknown APIKey field %s", name) return fmt.Errorf("unknown APIKey field %s", name)
} }

View File

@@ -102,6 +102,30 @@ func init() {
apikeyDescQuotaUsed := apikeyFields[9].Descriptor() apikeyDescQuotaUsed := apikeyFields[9].Descriptor()
// apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field. // apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field.
apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64) apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64)
// apikeyDescRateLimit5h is the schema descriptor for rate_limit_5h field.
apikeyDescRateLimit5h := apikeyFields[11].Descriptor()
// apikey.DefaultRateLimit5h holds the default value on creation for the rate_limit_5h field.
apikey.DefaultRateLimit5h = apikeyDescRateLimit5h.Default.(float64)
// apikeyDescRateLimit1d is the schema descriptor for rate_limit_1d field.
apikeyDescRateLimit1d := apikeyFields[12].Descriptor()
// apikey.DefaultRateLimit1d holds the default value on creation for the rate_limit_1d field.
apikey.DefaultRateLimit1d = apikeyDescRateLimit1d.Default.(float64)
// apikeyDescRateLimit7d is the schema descriptor for rate_limit_7d field.
apikeyDescRateLimit7d := apikeyFields[13].Descriptor()
// apikey.DefaultRateLimit7d holds the default value on creation for the rate_limit_7d field.
apikey.DefaultRateLimit7d = apikeyDescRateLimit7d.Default.(float64)
// apikeyDescUsage5h is the schema descriptor for usage_5h field.
apikeyDescUsage5h := apikeyFields[14].Descriptor()
// apikey.DefaultUsage5h holds the default value on creation for the usage_5h field.
apikey.DefaultUsage5h = apikeyDescUsage5h.Default.(float64)
// apikeyDescUsage1d is the schema descriptor for usage_1d field.
apikeyDescUsage1d := apikeyFields[15].Descriptor()
// apikey.DefaultUsage1d holds the default value on creation for the usage_1d field.
apikey.DefaultUsage1d = apikeyDescUsage1d.Default.(float64)
// apikeyDescUsage7d is the schema descriptor for usage_7d field.
apikeyDescUsage7d := apikeyFields[16].Descriptor()
// apikey.DefaultUsage7d holds the default value on creation for the usage_7d field.
apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64)
accountMixin := schema.Account{}.Mixin() accountMixin := schema.Account{}.Mixin()
accountMixinHooks1 := accountMixin[1].Hooks() accountMixinHooks1 := accountMixin[1].Hooks()
account.Hooks[0] = accountMixinHooks1[0] account.Hooks[0] = accountMixinHooks1[0]

View File

@@ -74,6 +74,47 @@ func (APIKey) Fields() []ent.Field {
Optional(). Optional().
Nillable(). Nillable().
Comment("Expiration time for this API key (null = never expires)"), Comment("Expiration time for this API key (null = never expires)"),
// ========== Rate limit fields ==========
// Rate limit configuration (0 = unlimited)
field.Float("rate_limit_5h").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0).
Comment("Rate limit in USD per 5 hours (0 = unlimited)"),
field.Float("rate_limit_1d").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0).
Comment("Rate limit in USD per day (0 = unlimited)"),
field.Float("rate_limit_7d").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0).
Comment("Rate limit in USD per 7 days (0 = unlimited)"),
// Rate limit usage tracking
field.Float("usage_5h").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0).
Comment("Used amount in USD for the current 5h window"),
field.Float("usage_1d").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0).
Comment("Used amount in USD for the current 1d window"),
field.Float("usage_7d").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0).
Comment("Used amount in USD for the current 7d window"),
// Window start times
field.Time("window_5h_start").
Optional().
Nillable().
Comment("Start time of the current 5h rate limit window"),
field.Time("window_1d_start").
Optional().
Nillable().
Comment("Start time of the current 1d rate limit window"),
field.Time("window_7d_start").
Optional().
Nillable().
Comment("Start time of the current 7d rate limit window"),
} }
} }

View File

@@ -180,8 +180,6 @@ require (
golang.org/x/text v0.34.0 // indirect golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect golang.org/x/tools v0.41.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect

View File

@@ -872,7 +872,8 @@ type DefaultConfig struct {
} }
type RateLimitConfig struct { type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401临时不可调度冷却(分钟)
} }
// APIKeyAuthCacheConfig API Key 认证缓存配置 // APIKeyAuthCacheConfig API Key 认证缓存配置
@@ -1260,6 +1261,7 @@ func setDefaults() {
// RateLimit // RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit避免分支漂移 // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit避免分支漂移
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")

View File

@@ -123,6 +123,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsQueryModeDefault: settings.OpsQueryModeDefault, OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion, MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
}) })
} }
@@ -193,6 +194,9 @@ type UpdateSettingsRequest struct {
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"` OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
MinClaudeCodeVersion string `json:"min_claude_code_version"` MinClaudeCodeVersion string `json:"min_claude_code_version"`
// 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
} }
// UpdateSettings 更新系统设置 // UpdateSettings 更新系统设置
@@ -465,6 +469,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableIdentityPatch: req.EnableIdentityPatch, EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt, IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion, MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
OpsMonitoringEnabled: func() bool { OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil { if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled return *req.OpsMonitoringEnabled
@@ -561,6 +566,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
}) })
} }
@@ -709,6 +715,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion { if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
changed = append(changed, "min_claude_code_version") changed = append(changed, "min_claude_code_version")
} }
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
changed = append(changed, "allow_ungrouped_key_scheduling")
}
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled { if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
changed = append(changed, "purchase_subscription_enabled") changed = append(changed, "purchase_subscription_enabled")
} }

View File

@@ -36,6 +36,11 @@ type CreateAPIKeyRequest struct {
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
Quota *float64 `json:"quota"` // 配额限制 (USD) Quota *float64 `json:"quota"` // 配额限制 (USD)
ExpiresInDays *int `json:"expires_in_days"` // 过期天数 ExpiresInDays *int `json:"expires_in_days"` // 过期天数
// Rate limit fields (0 = unlimited)
RateLimit5h *float64 `json:"rate_limit_5h"`
RateLimit1d *float64 `json:"rate_limit_1d"`
RateLimit7d *float64 `json:"rate_limit_7d"`
} }
// UpdateAPIKeyRequest represents the update API key request payload // UpdateAPIKeyRequest represents the update API key request payload
@@ -48,6 +53,12 @@ type UpdateAPIKeyRequest struct {
Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制 Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制
ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601) ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601)
ResetQuota *bool `json:"reset_quota"` // 重置已用配额 ResetQuota *bool `json:"reset_quota"` // 重置已用配额
// Rate limit fields (nil = no change, 0 = unlimited)
RateLimit5h *float64 `json:"rate_limit_5h"`
RateLimit1d *float64 `json:"rate_limit_1d"`
RateLimit7d *float64 `json:"rate_limit_7d"`
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // 重置限速用量
} }
// List handles listing user's API keys with pagination // List handles listing user's API keys with pagination
@@ -131,6 +142,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
if req.Quota != nil { if req.Quota != nil {
svcReq.Quota = *req.Quota svcReq.Quota = *req.Quota
} }
if req.RateLimit5h != nil {
svcReq.RateLimit5h = *req.RateLimit5h
}
if req.RateLimit1d != nil {
svcReq.RateLimit1d = *req.RateLimit1d
}
if req.RateLimit7d != nil {
svcReq.RateLimit7d = *req.RateLimit7d
}
executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq) key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
@@ -163,10 +183,14 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
} }
svcReq := service.UpdateAPIKeyRequest{ svcReq := service.UpdateAPIKeyRequest{
IPWhitelist: req.IPWhitelist, IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist, IPBlacklist: req.IPBlacklist,
Quota: req.Quota, Quota: req.Quota,
ResetQuota: req.ResetQuota, ResetQuota: req.ResetQuota,
RateLimit5h: req.RateLimit5h,
RateLimit1d: req.RateLimit1d,
RateLimit7d: req.RateLimit7d,
ResetRateLimitUsage: req.ResetRateLimitUsage,
} }
if req.Name != "" { if req.Name != "" {
svcReq.Name = &req.Name svcReq.Name = &req.Name

View File

@@ -72,22 +72,31 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
return nil return nil
} }
return &APIKey{ return &APIKey{
ID: k.ID, ID: k.ID,
UserID: k.UserID, UserID: k.UserID,
Key: k.Key, Key: k.Key,
Name: k.Name, Name: k.Name,
GroupID: k.GroupID, GroupID: k.GroupID,
Status: k.Status, Status: k.Status,
IPWhitelist: k.IPWhitelist, IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist, IPBlacklist: k.IPBlacklist,
LastUsedAt: k.LastUsedAt, LastUsedAt: k.LastUsedAt,
Quota: k.Quota, Quota: k.Quota,
QuotaUsed: k.QuotaUsed, QuotaUsed: k.QuotaUsed,
ExpiresAt: k.ExpiresAt, ExpiresAt: k.ExpiresAt,
CreatedAt: k.CreatedAt, CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt, UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User), RateLimit5h: k.RateLimit5h,
Group: GroupFromServiceShallow(k.Group), RateLimit1d: k.RateLimit1d,
RateLimit7d: k.RateLimit7d,
Usage5h: k.Usage5h,
Usage1d: k.Usage1d,
Usage7d: k.Usage7d,
Window5hStart: k.Window5hStart,
Window1dStart: k.Window1dStart,
Window7dStart: k.Window7dStart,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
} }
} }

View File

@@ -77,6 +77,9 @@ type SystemSettings struct {
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"` OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
MinClaudeCodeVersion string `json:"min_claude_code_version"` MinClaudeCodeVersion string `json:"min_claude_code_version"`
// 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {

View File

@@ -47,6 +47,17 @@ type APIKey struct {
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
// Rate limit fields
RateLimit5h float64 `json:"rate_limit_5h"`
RateLimit1d float64 `json:"rate_limit_1d"`
RateLimit7d float64 `json:"rate_limit_7d"`
Usage5h float64 `json:"usage_5h"`
Usage1d float64 `json:"usage_1d"`
Usage7d float64 `json:"usage_7d"`
Window5hStart *time.Time `json:"window_5h_start"`
Window1dStart *time.Time `json:"window_1d_start"`
Window7dStart *time.Time `json:"window_7d_start"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"` Group *Group `json:"group,omitempty"`
} }

View File

@@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -844,6 +845,10 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service
// Usage handles getting account balance and usage statistics for CC Switch integration // Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage // GET /v1/usage
//
// Two modes:
// - quota_limited: API Key has quota or rate limits configured. Returns key-level limits/usage.
// - unrestricted: No key-level limits. Returns subscription or wallet balance info.
func (h *GatewayHandler) Usage(c *gin.Context) { func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c) apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok { if !ok {
@@ -857,54 +862,183 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return return
} }
ctx := c.Request.Context()
// 解析可选的日期范围参数(用于 model_stats 查询)
startTime, endTime := h.parseUsageDateRange(c)
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应 // Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
var usageData gin.H usageData := h.buildUsageData(ctx, apiKey.ID)
// Best-effort: 获取模型统计
var modelStats any
if h.usageService != nil { if h.usageService != nil {
dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID) if stats, err := h.usageService.GetAPIKeyModelStats(ctx, apiKey.ID, startTime, endTime); err == nil && len(stats) > 0 {
if err == nil && dashStats != nil { modelStats = stats
usageData = gin.H{ }
"today": gin.H{ }
"requests": dashStats.TodayRequests,
"input_tokens": dashStats.TodayInputTokens, // 判断模式: key 有总额度或速率限制 → quota_limited否则 → unrestricted
"output_tokens": dashStats.TodayOutputTokens, isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits()
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
"cache_read_tokens": dashStats.TodayCacheReadTokens, if isQuotaLimited {
"total_tokens": dashStats.TodayTokens, h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats)
"cost": dashStats.TodayCost, return
"actual_cost": dashStats.TodayActualCost, }
},
"total": gin.H{ h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats)
"requests": dashStats.TotalRequests, }
"input_tokens": dashStats.TotalInputTokens,
"output_tokens": dashStats.TotalOutputTokens, // parseUsageDateRange 解析 start_date / end_date query params默认返回近 30 天范围
"cache_creation_tokens": dashStats.TotalCacheCreationTokens, func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Time) {
"cache_read_tokens": dashStats.TotalCacheReadTokens, now := timezone.Now()
"total_tokens": dashStats.TotalTokens, endTime := now
"cost": dashStats.TotalCost, startTime := now.AddDate(0, 0, -30)
"actual_cost": dashStats.TotalActualCost,
}, if s := c.Query("start_date"); s != "" {
"average_duration_ms": dashStats.AverageDurationMs, if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
"rpm": dashStats.Rpm, startTime = t
"tpm": dashStats.Tpm, }
}
if s := c.Query("end_date"); s != "" {
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
endTime = t.Add(24*time.Hour - time.Second) // end of day
}
}
return startTime, endTime
}
// buildUsageData 构建 today/total 用量摘要
func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin.H {
if h.usageService == nil {
return nil
}
dashStats, err := h.usageService.GetAPIKeyDashboardStats(ctx, apiKeyID)
if err != nil || dashStats == nil {
return nil
}
return gin.H{
"today": gin.H{
"requests": dashStats.TodayRequests,
"input_tokens": dashStats.TodayInputTokens,
"output_tokens": dashStats.TodayOutputTokens,
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
"cache_read_tokens": dashStats.TodayCacheReadTokens,
"total_tokens": dashStats.TodayTokens,
"cost": dashStats.TodayCost,
"actual_cost": dashStats.TodayActualCost,
},
"total": gin.H{
"requests": dashStats.TotalRequests,
"input_tokens": dashStats.TotalInputTokens,
"output_tokens": dashStats.TotalOutputTokens,
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
"cache_read_tokens": dashStats.TotalCacheReadTokens,
"total_tokens": dashStats.TotalTokens,
"cost": dashStats.TotalCost,
"actual_cost": dashStats.TotalActualCost,
},
"average_duration_ms": dashStats.AverageDurationMs,
"rpm": dashStats.Rpm,
"tpm": dashStats.Tpm,
}
}
// usageQuotaLimited 处理 quota_limited 模式的响应
func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) {
resp := gin.H{
"mode": "quota_limited",
"isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired,
"status": apiKey.Status,
}
// 总额度信息
if apiKey.Quota > 0 {
remaining := apiKey.GetQuotaRemaining()
resp["quota"] = gin.H{
"limit": apiKey.Quota,
"used": apiKey.QuotaUsed,
"remaining": remaining,
"unit": "USD",
}
resp["remaining"] = remaining
resp["unit"] = "USD"
}
// 速率限制信息(从 DB 获取实时用量)
if apiKey.HasRateLimits() && h.apiKeyService != nil {
rateLimitData, err := h.apiKeyService.GetRateLimitData(ctx, apiKey.ID)
if err == nil && rateLimitData != nil {
var rateLimits []gin.H
if apiKey.RateLimit5h > 0 {
used := rateLimitData.Usage5h
rateLimits = append(rateLimits, gin.H{
"window": "5h",
"limit": apiKey.RateLimit5h,
"used": used,
"remaining": max(0, apiKey.RateLimit5h-used),
"window_start": rateLimitData.Window5hStart,
})
}
if apiKey.RateLimit1d > 0 {
used := rateLimitData.Usage1d
rateLimits = append(rateLimits, gin.H{
"window": "1d",
"limit": apiKey.RateLimit1d,
"used": used,
"remaining": max(0, apiKey.RateLimit1d-used),
"window_start": rateLimitData.Window1dStart,
})
}
if apiKey.RateLimit7d > 0 {
used := rateLimitData.Usage7d
rateLimits = append(rateLimits, gin.H{
"window": "7d",
"limit": apiKey.RateLimit7d,
"used": used,
"remaining": max(0, apiKey.RateLimit7d-used),
"window_start": rateLimitData.Window7dStart,
})
}
if len(rateLimits) > 0 {
resp["rate_limits"] = rateLimits
} }
} }
} }
// 订阅模式:返回订阅限额信息 + 用量统计 // 过期时间
if apiKey.ExpiresAt != nil {
resp["expires_at"] = apiKey.ExpiresAt
resp["days_until_expiry"] = apiKey.GetDaysUntilExpiry()
}
if usageData != nil {
resp["usage"] = usageData
}
if modelStats != nil {
resp["model_stats"] = modelStats
}
c.JSON(http.StatusOK, resp)
}
// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容)
func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) {
// 订阅模式
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware2.GetSubscriptionFromContext(c) resp := gin.H{
if !ok { "mode": "unrestricted",
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription") "isValid": true,
return "planName": apiKey.Group.Name,
"unit": "USD",
} }
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) // 订阅信息可能不在 context 中(/v1/usage 路径跳过了中间件的计费检查)
resp := gin.H{ subscription, ok := middleware2.GetSubscriptionFromContext(c)
"isValid": true, if ok {
"planName": apiKey.Group.Name, remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
"remaining": remaining, resp["remaining"] = remaining
"unit": "USD", resp["subscription"] = gin.H{
"subscription": gin.H{
"daily_usage_usd": subscription.DailyUsageUSD, "daily_usage_usd": subscription.DailyUsageUSD,
"weekly_usage_usd": subscription.WeeklyUsageUSD, "weekly_usage_usd": subscription.WeeklyUsageUSD,
"monthly_usage_usd": subscription.MonthlyUsageUSD, "monthly_usage_usd": subscription.MonthlyUsageUSD,
@@ -912,23 +1046,28 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD, "weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD, "monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
"expires_at": subscription.ExpiresAt, "expires_at": subscription.ExpiresAt,
}, }
} }
if usageData != nil { if usageData != nil {
resp["usage"] = usageData resp["usage"] = usageData
} }
if modelStats != nil {
resp["model_stats"] = modelStats
}
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
return return
} }
// 余额模式:返回钱包余额 + 用量统计 // 余额模式
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) latestUser, err := h.userService.GetByID(ctx, subject.UserID)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return return
} }
resp := gin.H{ resp := gin.H{
"mode": "unrestricted",
"isValid": true, "isValid": true,
"planName": "钱包余额", "planName": "钱包余额",
"remaining": latestUser.Balance, "remaining": latestUser.Balance,
@@ -938,6 +1077,9 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
if usageData != nil { if usageData != nil {
resp["usage"] = usageData resp["usage"] = usageData
} }
if modelStats != nil {
resp["model_stats"] = modelStats
}
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
@@ -1445,6 +1587,18 @@ func billingErrorDetails(err error) (status int, code, message string) {
} }
return http.StatusServiceUnavailable, "billing_service_error", msg return http.StatusServiceUnavailable, "billing_service_error", msg
} }
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
}
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
}
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
}
msg := pkgerrors.Message(err) msg := pkgerrors.Message(err)
if msg == "" { if msg == "" {
logger.L().With( logger.L().With(

View File

@@ -159,7 +159,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。 // RunModeSimple跳过计费检查避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)

View File

@@ -1032,6 +1032,15 @@ func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64
func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error { func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
return nil return nil
} }
func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error {
return nil
}
func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error {
return nil
}
func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) {
return nil, nil
}
// newTestAPIKeyService 创建测试用的 APIKeyService // newTestAPIKeyService 创建测试用的 APIKeyService
func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService { func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
@@ -2089,6 +2098,12 @@ func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context,
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
return r.accounts, nil return r.accounts, nil
} }
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
return nil return nil
} }

View File

@@ -182,6 +182,12 @@ func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platfo
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms) return r.ListSchedulableByPlatforms(ctx, platforms)
} }
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }
@@ -405,7 +411,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
deferredService := service.NewDeferredService(accountRepo, nil, 0) deferredService := service.NewDeferredService(accountRepo, nil, 0)
billingService := service.NewBillingService(cfg, nil) billingService := service.NewBillingService(cfg, nil)
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{}) concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg) billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
t.Cleanup(func() { t.Cleanup(func() {
billingCacheService.Stop() billingCacheService.Stop()
}) })

View File

@@ -53,8 +53,7 @@ const (
var defaultUserAgentVersion = "1.19.6" var defaultUserAgentVersion = "1.19.6"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 // defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
// 默认值使用占位符,生产环境请通过环境变量注入真实值。 var defaultClientSecret = "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"
var defaultClientSecret = "GOCSPX-your-client-secret"
func init() { func init() {
// 从环境变量读取版本号,未设置则使用默认值 // 从环境变量读取版本号,未设置则使用默认值

View File

@@ -684,7 +684,7 @@ func TestConstants_值正确(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err) t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
} }
if secret != "GOCSPX-your-client-secret" { if secret != "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" {
t.Errorf("默认 client_secret 不匹配: got %s", secret) t.Errorf("默认 client_secret 不匹配: got %s", secret)
} }
if RedirectURI != "http://localhost:8085/callback" { if RedirectURI != "http://localhost:8085/callback" {

View File

@@ -39,7 +39,7 @@ const (
// They enable the "login without creating your own OAuth client" experience, but Google may // They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client. // restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret" GeminiCLIOAuthClientSecret = "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"
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret. // GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET" GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"

View File

@@ -829,6 +829,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
return r.accountsToService(ctx, accounts) return r.accountsToService(ctx, accounts)
} }
func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ(platform),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformIn(platforms...),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 { if len(platforms) == 0 {
return nil, nil return nil, nil

View File

@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx) userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx) groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewAPIKeyRepository(entClient) apiKeyRepo := newAPIKeyRepositoryWithSQL(entClient, tx)
u := &service.User{ u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com", Email: uniqueTestValue(t, "cascade-user") + "@example.com",

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"database/sql"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -16,10 +17,15 @@ import (
type apiKeyRepository struct { type apiKeyRepository struct {
client *dbent.Client client *dbent.Client
sql sqlExecutor
} }
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository { func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository {
return &apiKeyRepository{client: client} return newAPIKeyRepositoryWithSQL(client, sqlDB)
}
func newAPIKeyRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *apiKeyRepository {
return &apiKeyRepository{client: client, sql: sqlq}
} }
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery { func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
@@ -37,7 +43,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetNillableLastUsedAt(key.LastUsedAt). SetNillableLastUsedAt(key.LastUsedAt).
SetQuota(key.Quota). SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed). SetQuotaUsed(key.QuotaUsed).
SetNillableExpiresAt(key.ExpiresAt) SetNillableExpiresAt(key.ExpiresAt).
SetRateLimit5h(key.RateLimit5h).
SetRateLimit1d(key.RateLimit1d).
SetRateLimit7d(key.RateLimit7d)
if len(key.IPWhitelist) > 0 { if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist) builder.SetIPWhitelist(key.IPWhitelist)
@@ -118,6 +127,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldQuota, apikey.FieldQuota,
apikey.FieldQuotaUsed, apikey.FieldQuotaUsed,
apikey.FieldExpiresAt, apikey.FieldExpiresAt,
apikey.FieldRateLimit5h,
apikey.FieldRateLimit1d,
apikey.FieldRateLimit7d,
). ).
WithUser(func(q *dbent.UserQuery) { WithUser(func(q *dbent.UserQuery) {
q.Select( q.Select(
@@ -179,6 +191,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
SetStatus(key.Status). SetStatus(key.Status).
SetQuota(key.Quota). SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed). SetQuotaUsed(key.QuotaUsed).
SetRateLimit5h(key.RateLimit5h).
SetRateLimit1d(key.RateLimit1d).
SetRateLimit7d(key.RateLimit7d).
SetUsage5h(key.Usage5h).
SetUsage1d(key.Usage1d).
SetUsage7d(key.Usage7d).
SetUpdatedAt(now) SetUpdatedAt(now)
if key.GroupID != nil { if key.GroupID != nil {
builder.SetGroupID(*key.GroupID) builder.SetGroupID(*key.GroupID)
@@ -193,6 +211,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearExpiresAt() builder.ClearExpiresAt()
} }
// Rate limit window start times
if key.Window5hStart != nil {
builder.SetWindow5hStart(*key.Window5hStart)
} else {
builder.ClearWindow5hStart()
}
if key.Window1dStart != nil {
builder.SetWindow1dStart(*key.Window1dStart)
} else {
builder.ClearWindow1dStart()
}
if key.Window7dStart != nil {
builder.SetWindow7dStart(*key.Window7dStart)
} else {
builder.ClearWindow7dStart()
}
// IP 限制字段 // IP 限制字段
if len(key.IPWhitelist) > 0 { if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist) builder.SetIPWhitelist(key.IPWhitelist)
@@ -412,25 +447,92 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt
return nil return nil
} }
// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes
// window start times via COALESCE if not already set.
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = usage_5h + $1,
usage_1d = usage_1d + $1,
usage_7d = usage_7d + $1,
window_5h_start = COALESCE(window_5h_start, NOW()),
window_1d_start = COALESCE(window_1d_start, NOW()),
window_7d_start = COALESCE(window_7d_start, NOW()),
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL`,
cost, id)
return err
}
// ResetRateLimitWindows resets expired rate limit windows atomically.
func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`,
id)
return err
}
// GetRateLimitData returns the current rate limit usage and window start times for an API key.
func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (result *service.APIKeyRateLimitData, err error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start
FROM api_keys
WHERE id = $1 AND deleted_at IS NULL`,
id)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
}
}()
if !rows.Next() {
return nil, service.ErrAPIKeyNotFound
}
data := &service.APIKeyRateLimitData{}
if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil {
return nil, err
}
return data, rows.Err()
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil { if m == nil {
return nil return nil
} }
out := &service.APIKey{ out := &service.APIKey{
ID: m.ID, ID: m.ID,
UserID: m.UserID, UserID: m.UserID,
Key: m.Key, Key: m.Key,
Name: m.Name, Name: m.Name,
Status: m.Status, Status: m.Status,
IPWhitelist: m.IPWhitelist, IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist, IPBlacklist: m.IPBlacklist,
LastUsedAt: m.LastUsedAt, LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt, CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt, UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID, GroupID: m.GroupID,
Quota: m.Quota, Quota: m.Quota,
QuotaUsed: m.QuotaUsed, QuotaUsed: m.QuotaUsed,
ExpiresAt: m.ExpiresAt, ExpiresAt: m.ExpiresAt,
RateLimit5h: m.RateLimit5h,
RateLimit1d: m.RateLimit1d,
RateLimit7d: m.RateLimit7d,
Usage5h: m.Usage5h,
Usage1d: m.Usage1d,
Usage7d: m.Usage7d,
Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart,
} }
if m.Edges.User != nil { if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User) out.User = userEntityToService(m.Edges.User)

View File

@@ -26,7 +26,7 @@ func (s *APIKeyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
tx := testEntTx(s.T()) tx := testEntTx(s.T())
s.client = tx.Client() s.client = tx.Client()
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository) s.repo = newAPIKeyRepositoryWithSQL(s.client, tx)
} }
func TestAPIKeyRepoSuite(t *testing.T) { func TestAPIKeyRepoSuite(t *testing.T) {
@@ -421,7 +421,7 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
// 注意:此测试使用 testEntClient非事务隔离数据会真正写入数据库。 // 注意:此测试使用 testEntClient非事务隔离数据会真正写入数据库。
func TestIncrementQuotaUsed_Concurrent(t *testing.T) { func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
client := testEntClient(t) client := testEntClient(t)
repo := NewAPIKeyRepository(client).(*apiKeyRepository) repo := NewAPIKeyRepository(client, integrationDB).(*apiKeyRepository)
ctx := context.Background() ctx := context.Background()
// 创建测试用户和 API Key // 创建测试用户和 API Key

View File

@@ -14,10 +14,12 @@ import (
) )
const ( const (
billingBalanceKeyPrefix = "billing:balance:" billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:" billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute billingRateLimitKeyPrefix = "apikey:rate:"
billingCacheJitter = 30 * time.Second billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
) )
// jitteredTTL 返回带随机抖动的 TTL防止缓存雪崩 // jitteredTTL 返回带随机抖动的 TTL防止缓存雪崩
@@ -49,6 +51,20 @@ const (
subFieldVersion = "version" subFieldVersion = "version"
) )
// billingRateLimitKey generates the Redis key for API key rate limit cache.
func billingRateLimitKey(keyID int64) string {
return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID)
}
const (
rateLimitFieldUsage5h = "usage_5h"
rateLimitFieldUsage1d = "usage_1d"
rateLimitFieldUsage7d = "usage_7d"
rateLimitFieldWindow5h = "window_5h"
rateLimitFieldWindow1d = "window_1d"
rateLimitFieldWindow7d = "window_7d"
)
var ( var (
deductBalanceScript = redis.NewScript(` deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1]) local current = redis.call('GET', KEYS[1])
@@ -73,6 +89,21 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2]) redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1 return 1
`) `)
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
updateRateLimitUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
) )
type billingCache struct { type billingCache struct {
@@ -195,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID,
key := billingSubKey(userID, groupID) key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) {
key := billingRateLimitKey(keyID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
data := &service.APIKeyRateLimitCacheData{}
if v, ok := result[rateLimitFieldUsage5h]; ok {
data.Usage5h, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldUsage1d]; ok {
data.Usage1d, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldUsage7d]; ok {
data.Usage7d, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldWindow5h]; ok {
data.Window5h, _ = strconv.ParseInt(v, 10, 64)
}
if v, ok := result[rateLimitFieldWindow1d]; ok {
data.Window1d, _ = strconv.ParseInt(v, 10, 64)
}
if v, ok := result[rateLimitFieldWindow7d]; ok {
data.Window7d, _ = strconv.ParseInt(v, 10, 64)
}
return data, nil
}
func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error {
if data == nil {
return nil
}
key := billingRateLimitKey(keyID)
fields := map[string]any{
rateLimitFieldUsage5h: data.Usage5h,
rateLimitFieldUsage1d: data.Usage1d,
rateLimitFieldUsage7d: data.Usage7d,
rateLimitFieldWindow5h: data.Window5h,
rateLimitFieldWindow1d: data.Window1d,
rateLimitFieldWindow7d: data.Window7d,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, rateLimitCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
key := billingRateLimitKey(keyID)
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
return err
}
return nil
}
func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
key := billingRateLimitKey(keyID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -66,6 +66,13 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {}, "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
}, },
}, },
"061_add_usage_log_request_type.sql": {
fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
acceptedDBChecksum: map[string]struct{}{
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
},
},
} }
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。

View File

@@ -25,6 +25,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
require.False(t, ok) require.False(t, ok)
}) })
t.Run("061历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"061_add_usage_log_request_type.sql",
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0",
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
)
require.True(t, ok)
})
t.Run("061第二个历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"061_add_usage_log_request_type.sql",
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3",
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
)
require.True(t, ok)
})
t.Run("非白名单迁移不兼容", func(t *testing.T) { t.Run("非白名单迁移不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible( ok := isMigrationChecksumCompatible(
"001_init.sql", "001_init.sql",

View File

@@ -41,7 +41,7 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewAPIKeyRepository(client) repo := NewAPIKeyRepository(client, integrationDB)
key := &service.APIKey{ key := &service.APIKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"), Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
@@ -73,7 +73,7 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewAPIKeyRepository(client) repo := NewAPIKeyRepository(client, integrationDB)
key := &service.APIKey{ key := &service.APIKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"), Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
@@ -93,7 +93,7 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewAPIKeyRepository(client) repo := NewAPIKeyRepository(client, integrationDB)
key := &service.APIKey{ key := &service.APIKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"), Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),

View File

@@ -1655,6 +1655,13 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
// GetUsageTrendWithFilters returns usage trend data with optional filters // GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
if shouldUsePreaggregatedTrend(granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) {
aggregated, aggregatedErr := r.getUsageTrendFromAggregates(ctx, startTime, endTime, granularity)
if aggregatedErr == nil && len(aggregated) > 0 {
return aggregated, nil
}
}
dateFormat := safeDateFormat(granularity) dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(` query := fmt.Sprintf(`
@@ -1719,6 +1726,78 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
return results, nil return results, nil
} }
func shouldUsePreaggregatedTrend(granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) bool {
if granularity != "day" && granularity != "hour" {
return false
}
return userID == 0 &&
apiKeyID == 0 &&
accountID == 0 &&
groupID == 0 &&
model == "" &&
requestType == nil &&
stream == nil &&
billingType == nil
}
func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := safeDateFormat(granularity)
query := ""
args := []any{startTime, endTime}
switch granularity {
case "hour":
query = fmt.Sprintf(`
SELECT
TO_CHAR(bucket_start, '%s') as date,
total_requests as requests,
input_tokens,
output_tokens,
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
ORDER BY bucket_start ASC
`, dateFormat)
case "day":
query = fmt.Sprintf(`
SELECT
TO_CHAR(bucket_date::timestamp, '%s') as date,
total_requests as requests,
input_tokens,
output_tokens,
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
FROM usage_dashboard_daily
WHERE bucket_date >= $1::date AND bucket_date < $2::date
ORDER BY bucket_date ASC
`, dateFormat)
default:
return nil, nil
}
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results, err = scanTrendRows(rows)
if err != nil {
return nil, err
}
return results, nil
}
// GetModelStatsWithFilters returns model statistics with optional filters // GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"

View File

@@ -86,6 +86,15 @@ func TestAPIContracts(t *testing.T) {
"last_used_at": null, "last_used_at": null,
"quota": 0, "quota": 0,
"quota_used": 0, "quota_used": 0,
"rate_limit_5h": 0,
"rate_limit_1d": 0,
"rate_limit_7d": 0,
"usage_5h": 0,
"usage_1d": 0,
"usage_7d": 0,
"window_5h_start": null,
"window_1d_start": null,
"window_7d_start": null,
"expires_at": null, "expires_at": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
@@ -126,6 +135,15 @@ func TestAPIContracts(t *testing.T) {
"last_used_at": null, "last_used_at": null,
"quota": 0, "quota": 0,
"quota_used": 0, "quota_used": 0,
"rate_limit_5h": 0,
"rate_limit_1d": 0,
"rate_limit_7d": 0,
"usage_5h": 0,
"usage_1d": 0,
"usage_7d": 0,
"window_5h_start": null,
"window_1d_start": null,
"window_7d_start": null,
"expires_at": null, "expires_at": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
@@ -514,6 +532,7 @@ func TestAPIContracts(t *testing.T) {
"purchase_subscription_enabled": false, "purchase_subscription_enabled": false,
"purchase_subscription_url": "", "purchase_subscription_url": "",
"min_claude_code_version": "", "min_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"custom_menu_items": [] "custom_menu_items": []
} }
}`, }`,
@@ -1027,6 +1046,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
@@ -1498,6 +1525,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
return nil return nil
} }
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
return nil
}
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
return nil
}
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
return nil, nil
}
type stubUsageLogRepo struct { type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog userLogs map[int64][]service.UsageLog
} }

View File

@@ -19,8 +19,16 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
} }
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证 // apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
//
// 中间件职责分为两层:
// - 鉴权Authentication验证 Key 有效性、用户状态、IP 限制 —— 始终执行
// - 计费执行Billing Enforcement过期/配额/订阅/余额检查 —— skipBilling 时整块跳过
//
// /v1/usage 端点只需鉴权,不需要计费执行(允许过期/配额耗尽的 Key 查询自身用量)。
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// ── 1. 提取 API Key ──────────────────────────────────────────
queryKey := strings.TrimSpace(c.Query("key")) queryKey := strings.TrimSpace(c.Query("key"))
queryApiKey := strings.TrimSpace(c.Query("api_key")) queryApiKey := strings.TrimSpace(c.Query("api_key"))
if queryKey != "" || queryApiKey != "" { if queryKey != "" || queryApiKey != "" {
@@ -56,7 +64,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return return
} }
// 从数据库验证API key // ── 2. 验证 Key 存在 ─────────────────────────────────────────
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, service.ErrAPIKeyNotFound) { if errors.Is(err, service.ErrAPIKeyNotFound) {
@@ -67,29 +76,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return return
} }
// 检查API key是否激活 // ── 3. 基础鉴权(始终执行) ─────────────────────────────────
if !apiKey.IsActive() {
// Provide more specific error message based on status
switch apiKey.Status {
case service.StatusAPIKeyQuotaExhausted:
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
case service.StatusAPIKeyExpired:
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
default:
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
}
return
}
// 检查API Key是否过期即使状态是active也要检查时间 // disabled / 未知状态 → 无条件拦截expired 和 quota_exhausted 留给计费阶段
if apiKey.IsExpired() { if !apiKey.IsActive() &&
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") apiKey.Status != service.StatusAPIKeyExpired &&
return apiKey.Status != service.StatusAPIKeyQuotaExhausted {
} AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
// 检查API Key配额是否耗尽
if apiKey.IsQuotaExhausted() {
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
return return
} }
@@ -116,8 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return return
} }
// ── 4. SimpleMode → early return ─────────────────────────────
if cfg.RunMode == config.RunModeSimple { if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
@@ -130,54 +124,89 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return return
} }
// 判断计费方式:订阅模式 vs 余额模式 // ── 5. 加载订阅(订阅模式时始终加载) ───────────────────────
// skipBilling: /v1/usage 只需鉴权,跳过所有计费执行
skipBilling := c.Request.URL.Path == "/v1/usage"
var subscription *service.UserSubscription
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil { if isSubscriptionType && subscriptionService != nil {
// 订阅模式获取订阅L1 缓存 + singleflight sub, subErr := subscriptionService.GetActiveSubscription(
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(), c.Request.Context(),
apiKey.User.ID, apiKey.User.ID,
apiKey.Group.ID, apiKey.Group.ID,
) )
if err != nil { if subErr != nil {
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group") if !skipBilling {
return AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
} return
// 合并验证 + 限额检查(纯内存操作)
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if err != nil {
code := "SUBSCRIPTION_INVALID"
status := 403
if errors.Is(err, service.ErrDailyLimitExceeded) ||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
code = "USAGE_LIMIT_EXCEEDED"
status = 429
} }
AbortWithError(c, status, code, err.Error()) // skipBilling: 订阅不存在也放行handler 会返回可用的数据
return } else {
} subscription = sub
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
// 窗口维护异步化(不阻塞请求)
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if needsMaintenance {
maintenanceCopy := *subscription
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
} }
} }
// 将API key和用户信息存入上下文 // ── 6. 计费执行skipBilling 时整块跳过) ────────────────────
if !skipBilling {
// Key 状态检查
switch apiKey.Status {
case service.StatusAPIKeyQuotaExhausted:
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
return
case service.StatusAPIKeyExpired:
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
return
}
// 运行时过期/配额检查(即使状态是 active也要检查时间和用量
if apiKey.IsExpired() {
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
return
}
if apiKey.IsQuotaExhausted() {
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
return
}
// 订阅模式:验证订阅限额
if subscription != nil {
needsMaintenance, validateErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if validateErr != nil {
code := "SUBSCRIPTION_INVALID"
status := 403
if errors.Is(validateErr, service.ErrDailyLimitExceeded) ||
errors.Is(validateErr, service.ErrWeeklyLimitExceeded) ||
errors.Is(validateErr, service.ErrMonthlyLimitExceeded) {
code = "USAGE_LIMIT_EXCEEDED"
status = 429
}
AbortWithError(c, status, code, validateErr.Error())
return
}
// 窗口维护异步化(不阻塞请求)
if needsMaintenance {
maintenanceCopy := *subscription
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
// 非订阅模式 或 订阅模式但 subscriptionService 未注入:回退到余额检查
if apiKey.User.Balance <= 0 {
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
}
}
}
// ── 7. 设置上下文 → Next ─────────────────────────────────────
if subscription != nil {
c.Set(string(ContextKeySubscription), subscription)
}
c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,

View File

@@ -95,6 +95,15 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
} }
return nil return nil
} }
func (f fakeAPIKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
return nil
}
func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
return nil
}
func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
return &service.APIKeyRateLimitData{}, nil
}
func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented") return errors.New("not implemented")

View File

@@ -588,6 +588,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
return nil return nil
} }
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
return nil
}
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
return nil
}
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
return nil, nil
}
type stubUserSubscriptionRepo struct { type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error updateStatus func(ctx context.Context, subscriptionID int64, status string) error

View File

@@ -2,8 +2,11 @@ package middleware
import ( import (
"context" "context"
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -71,3 +74,48 @@ func AbortWithError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, NewErrorResponse(code, message)) c.JSON(statusCode, NewErrorResponse(code, message))
c.Abort() c.Abort()
} }
// ──────────────────────────────────────────────────────────
// RequireGroupAssignment — 未分组 Key 拦截中间件
// ──────────────────────────────────────────────────────────
// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式)
type GatewayErrorWriter func(c *gin.Context, status int, message string)
// AnthropicErrorWriter 按 Anthropic API 规范输出错误
func AnthropicErrorWriter(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": "permission_error", "message": message},
})
}
// GoogleErrorWriter 按 Google API 规范输出错误
func GoogleErrorWriter(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
}
// RequireGroupAssignment 检查 API Key 是否已分配到分组,
// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。
func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc {
return func(c *gin.Context) {
apiKey, ok := GetAPIKeyFromContext(c)
if !ok || apiKey.GroupID != nil {
c.Next()
return
}
// 未分组 Key — 检查系统设置
if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) {
c.Next()
return
}
writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.")
c.Abort()
}
}

View File

@@ -81,7 +81,7 @@ func SetupRouter(
} }
// 注册路由 // 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient) registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
return r return r
} }
@@ -96,6 +96,7 @@ func registerRoutes(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
opsService *service.OpsService, opsService *service.OpsService,
settingService *service.SettingService,
cfg *config.Config, cfg *config.Config,
redisClient *redis.Client, redisClient *redis.Client,
) { ) {
@@ -110,5 +111,5 @@ func registerRoutes(
routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth) routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
} }

View File

@@ -19,6 +19,7 @@ func RegisterGatewayRoutes(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
opsService *service.OpsService, opsService *service.OpsService,
settingService *service.SettingService,
cfg *config.Config, cfg *config.Config,
) { ) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
@@ -30,12 +31,17 @@ func RegisterGatewayRoutes(
clientRequestID := middleware.ClientRequestID() clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
requireGroupGoogle := middleware.RequireGroupAssignment(settingService, middleware.GoogleErrorWriter)
// API网关Claude API兼容 // API网关Claude API兼容
gateway := r.Group("/v1") gateway := r.Group("/v1")
gateway.Use(bodyLimit) gateway.Use(bodyLimit)
gateway.Use(clientRequestID) gateway.Use(clientRequestID)
gateway.Use(opsErrorLogger) gateway.Use(opsErrorLogger)
gateway.Use(gin.HandlerFunc(apiKeyAuth)) gateway.Use(gin.HandlerFunc(apiKeyAuth))
gateway.Use(requireGroupAnthropic)
{ {
gateway.POST("/messages", h.Gateway.Messages) gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens) gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
@@ -61,6 +67,7 @@ func RegisterGatewayRoutes(
gemini.Use(clientRequestID) gemini.Use(clientRequestID)
gemini.Use(opsErrorLogger) gemini.Use(opsErrorLogger)
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(requireGroupGoogle)
{ {
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
@@ -69,11 +76,11 @@ func RegisterGatewayRoutes(
} }
// OpenAI Responses API不带v1前缀的别名 // OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// Antigravity 模型列表 // Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels) r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度) // Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1") antigravityV1 := r.Group("/antigravity/v1")
@@ -82,6 +89,7 @@ func RegisterGatewayRoutes(
antigravityV1.Use(opsErrorLogger) antigravityV1.Use(opsErrorLogger)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
antigravityV1.Use(requireGroupAnthropic)
{ {
antigravityV1.POST("/messages", h.Gateway.Messages) antigravityV1.POST("/messages", h.Gateway.Messages)
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens) antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
@@ -95,6 +103,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta.Use(opsErrorLogger) antigravityV1Beta.Use(opsErrorLogger)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(requireGroupGoogle)
{ {
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels) antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
@@ -108,6 +117,7 @@ func RegisterGatewayRoutes(
soraV1.Use(opsErrorLogger) soraV1.Use(opsErrorLogger)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth)) soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic)
{ {
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions) soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models) soraV1.GET("/models", h.Gateway.Models)

View File

@@ -54,6 +54,8 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error

View File

@@ -147,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
panic("unexpected ListSchedulableByGroupIDAndPlatforms call") panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
} }
func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
panic("unexpected ListSchedulableUngroupedByPlatform call")
}
func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
panic("unexpected ListSchedulableUngroupedByPlatforms call")
}
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
panic("unexpected SetRateLimited call") panic("unexpected SetRateLimited call")
} }

View File

@@ -127,6 +127,15 @@ func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64
func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error { func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
panic("unexpected") panic("unexpected")
} }
func (s *apiKeyRepoStubForGroupUpdate) IncrementRateLimitUsage(context.Context, int64, float64) error {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, int64) error {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected")
}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests. // groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type groupRepoStubForGroupUpdate struct { type groupRepoStubForGroupUpdate struct {

View File

@@ -348,6 +348,19 @@ func (s *billingCacheStub) InvalidateSubscriptionCache(ctx context.Context, user
return nil return nil
} }
func (s *billingCacheStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
panic("unexpected GetAPIKeyRateLimit call")
}
func (s *billingCacheStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
panic("unexpected SetAPIKeyRateLimit call")
}
func (s *billingCacheStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
panic("unexpected UpdateAPIKeyRateLimitUsage call")
}
func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
panic("unexpected InvalidateAPIKeyRateLimit call")
}
func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall { func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall {
t.Helper() t.Helper()
calls := make([]subscriptionInvalidateCall, 0, expected) calls := make([]subscriptionInvalidateCall, 0, expected)

View File

@@ -36,12 +36,28 @@ type APIKey struct {
Quota float64 // Quota limit in USD (0 = unlimited) Quota float64 // Quota limit in USD (0 = unlimited)
QuotaUsed float64 // Used quota amount QuotaUsed float64 // Used quota amount
ExpiresAt *time.Time // Expiration time (nil = never expires) ExpiresAt *time.Time // Expiration time (nil = never expires)
// Rate limit fields
RateLimit5h float64 // Rate limit in USD per 5h (0 = unlimited)
RateLimit1d float64 // Rate limit in USD per 1d (0 = unlimited)
RateLimit7d float64 // Rate limit in USD per 7d (0 = unlimited)
Usage5h float64 // Used amount in current 5h window
Usage1d float64 // Used amount in current 1d window
Usage7d float64 // Used amount in current 7d window
Window5hStart *time.Time // Start of current 5h window
Window1dStart *time.Time // Start of current 1d window
Window7dStart *time.Time // Start of current 7d window
} }
func (k *APIKey) IsActive() bool { func (k *APIKey) IsActive() bool {
return k.Status == StatusActive return k.Status == StatusActive
} }
// HasRateLimits returns true if any rate limit window is configured
func (k *APIKey) HasRateLimits() bool {
return k.RateLimit5h > 0 || k.RateLimit1d > 0 || k.RateLimit7d > 0
}
// IsExpired checks if the API key has expired // IsExpired checks if the API key has expired
func (k *APIKey) IsExpired() bool { func (k *APIKey) IsExpired() bool {
if k.ExpiresAt == nil { if k.ExpiresAt == nil {

View File

@@ -19,6 +19,11 @@ type APIKeyAuthSnapshot struct {
// Expiration field for API Key expiration feature // Expiration field for API Key expiration feature
ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires) ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires)
// Rate limit configuration (only limits, not usage - usage read from Redis at check time)
RateLimit5h float64 `json:"rate_limit_5h"`
RateLimit1d float64 `json:"rate_limit_1d"`
RateLimit7d float64 `json:"rate_limit_7d"`
} }
// APIKeyAuthUserSnapshot 用户快照 // APIKeyAuthUserSnapshot 用户快照

View File

@@ -209,6 +209,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
Quota: apiKey.Quota, Quota: apiKey.Quota,
QuotaUsed: apiKey.QuotaUsed, QuotaUsed: apiKey.QuotaUsed,
ExpiresAt: apiKey.ExpiresAt, ExpiresAt: apiKey.ExpiresAt,
RateLimit5h: apiKey.RateLimit5h,
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
User: APIKeyAuthUserSnapshot{ User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID, ID: apiKey.User.ID,
Status: apiKey.User.Status, Status: apiKey.User.Status,
@@ -262,6 +265,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
Quota: snapshot.Quota, Quota: snapshot.Quota,
QuotaUsed: snapshot.QuotaUsed, QuotaUsed: snapshot.QuotaUsed,
ExpiresAt: snapshot.ExpiresAt, ExpiresAt: snapshot.ExpiresAt,
RateLimit5h: snapshot.RateLimit5h,
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
User: &User{ User: &User{
ID: snapshot.User.ID, ID: snapshot.User.ID,
Status: snapshot.User.Status, Status: snapshot.User.Status,

View File

@@ -30,6 +30,11 @@ var (
ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期") ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted") // ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完") ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
// Rate limit errors
ErrAPIKeyRateLimit5hExceeded = infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5小时限额已用完")
ErrAPIKeyRateLimit1dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key 日限额已用完")
ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7天限额已用完")
) )
const ( const (
@@ -64,6 +69,21 @@ type APIKeyRepository interface {
// Quota methods // Quota methods
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error
// Rate limit methods
IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error
ResetRateLimitWindows(ctx context.Context, id int64) error
GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error)
}
// APIKeyRateLimitData holds rate limit usage and window state for an API key.
type APIKeyRateLimitData struct {
Usage5h float64
Usage1d float64
Usage7d float64
Window5hStart *time.Time
Window1dStart *time.Time
Window7dStart *time.Time
} }
// APIKeyCache defines cache operations for API key service // APIKeyCache defines cache operations for API key service
@@ -102,6 +122,11 @@ type CreateAPIKeyRequest struct {
// Quota fields // Quota fields
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires) ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires)
// Rate limit fields (0 = unlimited)
RateLimit5h float64 `json:"rate_limit_5h"`
RateLimit1d float64 `json:"rate_limit_1d"`
RateLimit7d float64 `json:"rate_limit_7d"`
} }
// UpdateAPIKeyRequest 更新API Key请求 // UpdateAPIKeyRequest 更新API Key请求
@@ -117,22 +142,34 @@ type UpdateAPIKeyRequest struct {
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change) ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change)
ClearExpiration bool `json:"-"` // Clear expiration (internal use) ClearExpiration bool `json:"-"` // Clear expiration (internal use)
ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0 ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0
// Rate limit fields (nil = no change, 0 = unlimited)
RateLimit5h *float64 `json:"rate_limit_5h"`
RateLimit1d *float64 `json:"rate_limit_1d"`
RateLimit7d *float64 `json:"rate_limit_7d"`
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0
} }
// APIKeyService API Key服务 // APIKeyService API Key服务
// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset.
type RateLimitCacheInvalidator interface {
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
}
type APIKeyService struct { type APIKeyService struct {
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository userGroupRateRepo UserGroupRateRepository
cache APIKeyCache cache APIKeyCache
cfg *config.Config rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache
authCacheL1 *ristretto.Cache cfg *config.Config
authCfg apiKeyAuthCacheConfig authCacheL1 *ristretto.Cache
authGroup singleflight.Group authCfg apiKeyAuthCacheConfig
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time) authGroup singleflight.Group
lastUsedTouchSF singleflight.Group lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
lastUsedTouchSF singleflight.Group
} }
// NewAPIKeyService 创建API Key服务实例 // NewAPIKeyService 创建API Key服务实例
@@ -158,6 +195,12 @@ func NewAPIKeyService(
return svc return svc
} }
// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator.
// Called after construction (e.g. in wire) to avoid circular dependencies.
func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) {
s.rateLimitCacheInvalid = inv
}
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) { func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
if apiKey == nil { if apiKey == nil {
return return
@@ -327,6 +370,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
IPBlacklist: req.IPBlacklist, IPBlacklist: req.IPBlacklist,
Quota: req.Quota, Quota: req.Quota,
QuotaUsed: 0, QuotaUsed: 0,
RateLimit5h: req.RateLimit5h,
RateLimit1d: req.RateLimit1d,
RateLimit7d: req.RateLimit7d,
} }
// Set expiration time if specified // Set expiration time if specified
@@ -519,6 +565,26 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
apiKey.IPWhitelist = req.IPWhitelist apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist apiKey.IPBlacklist = req.IPBlacklist
// Update rate limit configuration
if req.RateLimit5h != nil {
apiKey.RateLimit5h = *req.RateLimit5h
}
if req.RateLimit1d != nil {
apiKey.RateLimit1d = *req.RateLimit1d
}
if req.RateLimit7d != nil {
apiKey.RateLimit7d = *req.RateLimit7d
}
resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage
if resetRateLimit {
apiKey.Usage5h = 0
apiKey.Usage1d = 0
apiKey.Usage7d = 0
apiKey.Window5hStart = nil
apiKey.Window1dStart = nil
apiKey.Window7dStart = nil
}
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err) return nil, fmt.Errorf("update api key: %w", err)
} }
@@ -526,6 +592,11 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
s.InvalidateAuthCacheByKey(ctx, apiKey.Key) s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey) s.compileAPIKeyIPRules(apiKey)
// Invalidate Redis rate limit cache so reset takes effect immediately
if resetRateLimit && s.rateLimitCacheInvalid != nil {
_ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
}
return apiKey, nil return apiKey, nil
} }
@@ -746,3 +817,16 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
return nil return nil
} }
// GetRateLimitData returns rate limit usage and window state for an API key.
func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
return s.apiKeyRepo.GetRateLimitData(ctx, id)
}
// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB.
func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
if cost <= 0 {
return nil
}
return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost)
}

View File

@@ -106,6 +106,15 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount
func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
panic("unexpected UpdateLastUsed call") panic("unexpected UpdateLastUsed call")
} }
func (s *authRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *authRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *authRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
type authCacheStub struct { type authCacheStub struct {
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)

View File

@@ -134,6 +134,18 @@ func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
return nil return nil
} }
func (s *apiKeyRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *apiKeyRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *apiKeyRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。
// //

View File

@@ -40,6 +40,7 @@ const (
cacheWriteSetSubscription cacheWriteSetSubscription
cacheWriteUpdateSubscriptionUsage cacheWriteUpdateSubscriptionUsage
cacheWriteDeductBalance cacheWriteDeductBalance
cacheWriteUpdateRateLimitUsage
) )
// 异步缓存写入工作池配置 // 异步缓存写入工作池配置
@@ -68,19 +69,26 @@ type cacheWriteTask struct {
kind cacheWriteKind kind cacheWriteKind
userID int64 userID int64
groupID int64 groupID int64
apiKeyID int64
balance float64 balance float64
amount float64 amount float64
subscriptionData *subscriptionCacheData subscriptionData *subscriptionCacheData
} }
// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB.
type apiKeyRateLimitLoader interface {
GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error)
}
// BillingCacheService 计费缓存服务 // BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct { type BillingCacheService struct {
cache BillingCache cache BillingCache
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
cfg *config.Config apiKeyRateLimitLoader apiKeyRateLimitLoader
circuitBreaker *billingCircuitBreaker cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup cacheWriteWg sync.WaitGroup
@@ -96,12 +104,13 @@ type BillingCacheService struct {
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
svc := &BillingCacheService{ svc := &BillingCacheService{
cache: cache, cache: cache,
userRepo: userRepo, userRepo: userRepo,
subRepo: subRepo, subRepo: subRepo,
cfg: cfg, apiKeyRateLimitLoader: apiKeyRepo,
cfg: cfg,
} }
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers() svc.startCacheWriteWorkers()
@@ -188,6 +197,12 @@ func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err) logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
} }
} }
case cacheWriteUpdateRateLimitUsage:
if s.cache != nil {
if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err)
}
}
} }
cancel() cancel()
} }
@@ -204,6 +219,8 @@ func cacheWriteKindName(kind cacheWriteKind) string {
return "update_subscription_usage" return "update_subscription_usage"
case cacheWriteDeductBalance: case cacheWriteDeductBalance:
return "deduct_balance" return "deduct_balance"
case cacheWriteUpdateRateLimitUsage:
return "update_rate_limit_usage"
default: default:
return "unknown" return "unknown"
} }
@@ -476,6 +493,137 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
return nil return nil
} }
// ============================================
// API Key 限速缓存方法
// ============================================
// checkAPIKeyRateLimits checks rate limit windows for an API key.
// It loads usage from Redis cache (falling back to DB on cache miss),
// resets expired windows in-memory and triggers async DB reset,
// and returns an error if any window limit is exceeded.
func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error {
if s.cache == nil {
// No cache: fall back to reading from DB directly
if s.apiKeyRateLimitLoader == nil {
return nil
}
data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
if err != nil {
return nil // Don't block requests on DB errors
}
return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d,
data.Window5hStart, data.Window1dStart, data.Window7dStart)
}
cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID)
if err != nil {
// Cache miss: load from DB and populate cache
if s.apiKeyRateLimitLoader == nil {
return nil
}
dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
if dbErr != nil {
return nil // Don't block requests on DB errors
}
// Build cache entry from DB data
cacheEntry := &APIKeyRateLimitCacheData{
Usage5h: dbData.Usage5h,
Usage1d: dbData.Usage1d,
Usage7d: dbData.Usage7d,
}
if dbData.Window5hStart != nil {
cacheEntry.Window5h = dbData.Window5hStart.Unix()
}
if dbData.Window1dStart != nil {
cacheEntry.Window1d = dbData.Window1dStart.Unix()
}
if dbData.Window7dStart != nil {
cacheEntry.Window7d = dbData.Window7dStart.Unix()
}
_ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry)
cacheData = cacheEntry
}
var w5h, w1d, w7d *time.Time
if cacheData.Window5h > 0 {
t := time.Unix(cacheData.Window5h, 0)
w5h = &t
}
if cacheData.Window1d > 0 {
t := time.Unix(cacheData.Window1d, 0)
w1d = &t
}
if cacheData.Window7d > 0 {
t := time.Unix(cacheData.Window7d, 0)
w7d = &t
}
return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d)
}
// evaluateRateLimits checks usage against limits, triggering async resets for expired windows.
func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error {
needsReset := false
// Reset expired windows in-memory for check purposes
if w5h != nil && time.Since(*w5h) >= 5*time.Hour {
usage5h = 0
needsReset = true
}
if w1d != nil && time.Since(*w1d) >= 24*time.Hour {
usage1d = 0
needsReset = true
}
if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour {
usage7d = 0
needsReset = true
}
// Trigger async DB reset if any window expired
if needsReset {
keyID := apiKey.ID
go func() {
resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if s.apiKeyRateLimitLoader != nil {
// Use the repo directly - reset then reload cache
if loader, ok := s.apiKeyRateLimitLoader.(interface {
ResetRateLimitWindows(ctx context.Context, id int64) error
}); ok {
_ = loader.ResetRateLimitWindows(resetCtx, keyID)
}
}
// Invalidate cache so next request loads fresh data
if s.cache != nil {
_ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID)
}
}()
}
// Check limits
if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h {
return ErrAPIKeyRateLimit5hExceeded
}
if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d {
return ErrAPIKeyRateLimit1dExceeded
}
if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d {
return ErrAPIKeyRateLimit7dExceeded
}
return nil
}
// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache.
func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) {
if s.cache == nil {
return
}
s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteUpdateRateLimitUsage,
apiKeyID: apiKeyID,
amount: cost,
})
}
// ============================================ // ============================================
// 统一检查方法 // 统一检查方法
// ============================================ // ============================================
@@ -496,10 +644,23 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
if isSubscriptionMode { if isSubscriptionMode {
return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription) if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil {
return err
}
} else {
if err := s.checkBalanceEligibility(ctx, user.ID); err != nil {
return err
}
} }
return s.checkBalanceEligibility(ctx, user.ID) // Check API Key rate limits (applies to both billing modes)
if apiKey != nil && apiKey.HasRateLimits() {
if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil {
return err
}
}
return nil
} }
// checkBalanceEligibility 检查余额模式资格 // checkBalanceEligibility 检查余额模式资格

View File

@@ -51,6 +51,22 @@ func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context,
return nil return nil
} }
func (s *billingCacheMissStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
return nil, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
return nil
}
func (s *billingCacheMissStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
return nil
}
type balanceLoadUserRepoStub struct { type balanceLoadUserRepoStub struct {
mockUserRepo mockUserRepo
calls atomic.Int64 calls atomic.Int64
@@ -76,7 +92,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
delay: 80 * time.Millisecond, delay: 80 * time.Millisecond,
balance: 12.34, balance: 12.34,
} }
svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{}) svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
t.Cleanup(svc.Stop) t.Cleanup(svc.Stop)
const goroutines = 16 const goroutines = 16

View File

@@ -52,9 +52,25 @@ func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context
return nil return nil
} }
func (b *billingCacheWorkerStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
return nil, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
return nil
}
func (b *billingCacheWorkerStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
return nil
}
func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
return nil
}
func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{} cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop) t.Cleanup(svc.Stop)
start := time.Now() start := time.Now()
@@ -76,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{} cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
svc.Stop() svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{ enqueued := svc.enqueueCacheWrite(cacheWriteTask{

View File

@@ -10,6 +10,16 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
) )
// APIKeyRateLimitCacheData holds rate limit usage data cached in Redis.
type APIKeyRateLimitCacheData struct {
Usage5h float64 `json:"usage_5h"`
Usage1d float64 `json:"usage_1d"`
Usage7d float64 `json:"usage_7d"`
Window5h int64 `json:"window_5h"` // unix timestamp, 0 = not started
Window1d int64 `json:"window_1d"`
Window7d int64 `json:"window_7d"`
}
// BillingCache defines cache operations for billing service // BillingCache defines cache operations for billing service
type BillingCache interface { type BillingCache interface {
// Balance operations // Balance operations
@@ -23,6 +33,12 @@ type BillingCache interface {
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
// API Key rate limit operations
GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error)
SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
} }
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致 // ModelPricing 模型价格配置per-token价格与LiteLLM格式一致

View File

@@ -201,6 +201,9 @@ const (
// SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查) // SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查)
SettingKeyMinClaudeCodeVersion = "min_claude_code_version" SettingKeyMinClaudeCodeVersion = "min_claude_code_version"
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false未分组 Key 返回 403
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
) )
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -0,0 +1,363 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ============================================================================
// Part 1: isAccountInGroup 单元测试
// ============================================================================
func TestIsAccountInGroup(t *testing.T) {
svc := &GatewayService{}
groupID100 := int64(100)
groupID200 := int64(200)
tests := []struct {
name string
account *Account
groupID *int64
expected bool
}{
// groupID == nil无分组 API Key
{
"nil_groupID_ungrouped_account_nil_groups",
&Account{ID: 1, AccountGroups: nil},
nil, true,
},
{
"nil_groupID_ungrouped_account_empty_slice",
&Account{ID: 2, AccountGroups: []AccountGroup{}},
nil, true,
},
{
"nil_groupID_grouped_account_single",
&Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}},
nil, false,
},
{
"nil_groupID_grouped_account_multiple",
&Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
nil, false,
},
// groupID != nil有分组 API Key
{
"with_groupID_account_in_group",
&Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}},
&groupID100, true,
},
{
"with_groupID_account_not_in_group",
&Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}},
&groupID100, false,
},
{
"with_groupID_ungrouped_account",
&Account{ID: 7, AccountGroups: nil},
&groupID100, false,
},
{
"with_groupID_multi_group_account_match_one",
&Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
&groupID200, true,
},
{
"with_groupID_multi_group_account_no_match",
&Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}},
&groupID100, false,
},
// 防御性边界
{
"nil_account_nil_groupID",
nil,
nil, false,
},
{
"nil_account_with_groupID",
nil,
&groupID100, false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.isAccountInGroup(tt.account, tt.groupID)
require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期")
})
}
}
// ============================================================================
// Part 2: 分组隔离端到端调度测试
// ============================================================================
// groupAwareMockAccountRepo 嵌入 mockAccountRepoForPlatform覆写分组隔离相关方法。
// allAccounts 存储所有账号,分组查询方法按 AccountGroups 字段进行真实过滤。
type groupAwareMockAccountRepo struct {
*mockAccountRepoForPlatform
allAccounts []Account
}
// ListSchedulableUngroupedByPlatform 仅返回未分组账号AccountGroups 为空)
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range m.allAccounts {
if acc.Platform == platform && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
result = append(result, acc)
}
}
return result, nil
}
// ListSchedulableUngroupedByPlatforms 仅返回未分组账号(多平台版本)
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
platformSet := make(map[string]bool, len(platforms))
for _, p := range platforms {
platformSet[p] = true
}
var result []Account
for _, acc := range m.allAccounts {
if platformSet[acc.Platform] && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
result = append(result, acc)
}
}
return result, nil
}
// ListSchedulableByGroupIDAndPlatform 返回属于指定分组的账号
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
var result []Account
for _, acc := range m.allAccounts {
if acc.Platform == platform && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
result = append(result, acc)
}
}
return result, nil
}
// ListSchedulableByGroupIDAndPlatforms 返回属于指定分组的账号(多平台版本)
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
platformSet := make(map[string]bool, len(platforms))
for _, p := range platforms {
platformSet[p] = true
}
var result []Account
for _, acc := range m.allAccounts {
if platformSet[acc.Platform] && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
result = append(result, acc)
}
}
return result, nil
}
// accountBelongsToGroup 检查账号是否属于指定分组
func accountBelongsToGroup(acc Account, groupID int64) bool {
for _, ag := range acc.AccountGroups {
if ag.GroupID == groupID {
return true
}
}
return false
}
// Verify interface implementation
var _ AccountRepository = (*groupAwareMockAccountRepo)(nil)
// newGroupAwareMockRepo 创建分组感知的 mock repo
func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo {
byID := make(map[int64]*Account, len(accounts))
for i := range accounts {
byID[accounts[i].ID] = &accounts[i]
}
return &groupAwareMockAccountRepo{
mockAccountRepoForPlatform: &mockAccountRepoForPlatform{
accounts: accounts,
accountsByID: byID,
},
allAccounts: accounts,
}
}
func TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) {
// 场景:无分组 API KeygroupID=nil池中只有已分组账号 → 应返回错误
ctx := context.Background()
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}},
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 200}}},
}
repo := newGroupAwareMockRepo(accounts)
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
require.Error(t, err, "无分组 Key 不应调度到已分组账号")
require.Nil(t, acc)
}
func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) {
// 场景:有分组 API KeygroupID=100池中只有未分组账号 → 应返回错误
ctx := context.Background()
groupID := int64(100)
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: nil},
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{}},
}
repo := newGroupAwareMockRepo(accounts)
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
require.Error(t, err, "有分组 Key 不应调度到未分组账号")
require.Nil(t, acc)
}
func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) {
// 场景:无分组 API KeygroupID=nil池中有未分组和已分组账号 → 应只选中未分组的
ctx := context.Background()
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组,不应被选中
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: nil}, // 未分组,应被选中
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 已分组,不应被选中
}
repo := newGroupAwareMockRepo(accounts)
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
require.NoError(t, err, "应成功调度未分组账号")
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2")
}
func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) {
// 场景:有分组 API KeygroupID=100池中有未分组和多个分组账号 → 应只选中分组 100 内的
ctx := context.Background()
groupID := int64(100)
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: nil}, // 未分组,不应被选中
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 属于分组 200不应被选中
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 属于分组 100应被选中
}
repo := newGroupAwareMockRepo(accounts)
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
require.NoError(t, err, "应成功调度分组内账号")
require.NotNil(t, acc)
require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3")
}
// ============================================================================
// Part 3: SimpleMode 旁路测试
// ============================================================================
func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) {
// SimpleMode 应跳过分组隔离,使用 ListSchedulableByPlatform 返回所有账号。
// 测试非 useMixed 路径platform=openai不会触发 mixed 调度逻辑)。
ctx := context.Background()
// 混合未分组和已分组账号SimpleMode 下应全部可调度
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组
{ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: nil}, // 未分组
}
// 使用基础 mockListSchedulableByPlatform 返回所有匹配平台的账号,不做分组过滤)
byID := make(map[int64]*Account, len(accounts))
for i := range accounts {
byID[accounts[i].ID] = &accounts[i]
}
repo := &mockAccountRepoForPlatform{
accounts: accounts,
accountsByID: byID,
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: &config.Config{RunMode: config.RunModeSimple},
}
// groupID=nil 时SimpleMode 应使用 ListSchedulableByPlatform不过滤分组
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号")
require.NotNil(t, acc)
// 应选择优先级最高的账号Priority=1, ID=2即使它未分组
require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组")
}
func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t *testing.T) {
// SimpleMode + groupID=nil 时,已分组账号也应该可被调度
ctx := context.Background()
// 只有已分组账号,在 standard 模式下 groupID=nil 会报错,但 simple 模式应正常
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}},
}
byID := make(map[int64]*Account, len(accounts))
for i := range accounts {
byID[accounts[i].ID] = &accounts[i]
}
repo := &mockAccountRepoForPlatform{
accounts: accounts,
accountsByID: byID,
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: &config.Config{RunMode: config.RunModeSimple},
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
require.NoError(t, err, "SimpleMode 下已分组账号也应可调度")
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号")
}

View File

@@ -147,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms) return m.ListSchedulableByPlatforms(ctx, platforms)
} }
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return m.ListSchedulableByPlatform(ctx, platform)
}
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }

View File

@@ -1782,8 +1782,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else { } else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms)
} }
if err != nil { if err != nil {
slog.Debug("account_scheduling_list_failed", slog.Debug("account_scheduling_list_failed",
@@ -1824,7 +1826,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform)
} }
if err != nil { if err != nil {
slog.Debug("account_scheduling_list_failed", slog.Debug("account_scheduling_list_failed",
@@ -1964,14 +1966,15 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
} }
// isAccountInGroup checks if the account belongs to the specified group. // isAccountInGroup checks if the account belongs to the specified group.
// Returns true if groupID is nil (no group restriction) or account belongs to the group. // When groupID is nil, returns true only for ungrouped accounts (no group assignments).
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
if groupID == nil {
return true // 无分组限制
}
if account == nil { if account == nil {
return false return false
} }
if groupID == nil {
// 无分组的 API Key 只能使用未分组的账号
return len(account.AccountGroups) == 0
}
for _, ag := range account.AccountGroups { for _, ag := range account.AccountGroups {
if ag.GroupID == *groupID { if ag.GroupID == *groupID {
return true return true
@@ -6361,9 +6364,10 @@ type RecordUsageInput struct {
APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额 APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额
} }
// APIKeyQuotaUpdater defines the interface for updating API Key quota // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
type APIKeyQuotaUpdater interface { type APIKeyQuotaUpdater interface {
UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
} }
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -6557,6 +6561,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
} }
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// Schedule batch update for account last_used_at // Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID) s.deferredService.ScheduleLastUsedUpdate(account.ID)
@@ -6746,6 +6758,14 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} }
} }
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// Schedule batch update for account last_used_at // Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID) s.deferredService.ScheduleLastUsedUpdate(account.ID)

View File

@@ -431,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co
if groupID != nil { if groupID != nil {
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms) return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
} }
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms)
} }
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) { func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {

View File

@@ -138,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
} }
return m.ListSchedulableByPlatforms(ctx, platforms) return m.ListSchedulableByPlatforms(ctx, platforms)
} }
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return m.ListSchedulableByPlatform(ctx, platform)
}
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }

View File

@@ -1343,7 +1343,7 @@ func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, grou
} else if groupID != nil { } else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -3492,6 +3492,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
} }
} }
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// Schedule batch update for account last_used_at // Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID) s.deferredService.ScheduleLastUsedUpdate(account.ID)

View File

@@ -57,6 +57,10 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl
return result, nil return result, nil
} }
func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
type stubConcurrencyCache struct { type stubConcurrencyCache struct {
ConcurrencyCache ConcurrencyCache
loadBatchErr error loadBatchErr error

View File

@@ -368,7 +368,7 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
Aggregation: OpsAggregationSettings{ Aggregation: OpsAggregationSettings{
AggregationEnabled: false, AggregationEnabled: false,
}, },
IgnoreCountTokensErrors: false, IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
AutoRefreshEnabled: false, AutoRefreshEnabled: false,

View File

@@ -146,13 +146,29 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} else { } else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
} }
// 3. 临时不可调度,替代 SetError保持 status=active 让刷新服务能拾取)
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "OAuth 401: " + upstreamMsg
}
cooldownMinutes := s.cfg.RateLimit.OAuth401CooldownMinutes
if cooldownMinutes <= 0 {
cooldownMinutes = 10
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil {
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
}
shouldDisable = true
} else {
// 非 OAuth 账号APIKey保持原有 SetError 行为
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
} }
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 402: case 402:
// 支付要求:余额不足或计费问题,停止调度 // 支付要求:余额不足或计费问题,停止调度
msg := "Payment required (402): insufficient balance or billing issue" msg := "Payment required (402): insufficient balance or billing issue"

View File

@@ -41,7 +41,7 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
return r.err return r.err
} }
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) { func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
platform string platform string
@@ -76,9 +76,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable) require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls) require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls) require.Equal(t, 1, repo.tempCalls)
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
require.Len(t, invalidator.accounts, 1) require.Len(t, invalidator.accounts, 1)
}) })
} }
@@ -98,7 +97,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable) require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls) require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1) require.Len(t, invalidator.accounts, 1)
} }

View File

@@ -605,8 +605,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke
var err error var err error
if groupID > 0 { if groupID > 0 {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms)
} else { } else if s.isRunModeSimple() {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -624,7 +626,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke
if groupID > 0 { if groupID > 0 {
return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform) return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform)
} }
return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) if s.isRunModeSimple() {
return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform)
}
return s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, bucket.Platform)
} }
func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket { func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket {

View File

@@ -438,6 +438,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// Claude Code version check // Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
// 分组隔离
updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling)
err = s.settingRepo.SetMultiple(ctx, updates) err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil { if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
@@ -646,6 +649,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// Claude Code version check (default: empty = disabled) // Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "", SettingKeyMinClaudeCodeVersion: "",
// 分组隔离(默认不允许未分组 Key 调度)
SettingKeyAllowUngroupedKeyScheduling: "false",
} }
return s.settingRepo.SetMultiple(ctx, defaults) return s.settingRepo.SetMultiple(ctx, defaults)
@@ -776,6 +782,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Claude Code version check // Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
// 分组隔离
result.AllowUngroupedKeyScheduling = settings[SettingKeyAllowUngroupedKeyScheduling] == "true"
return result return result
} }
@@ -1098,6 +1107,15 @@ func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamT
return &settings, nil return &settings, nil
} }
// IsUngroupedKeySchedulingAllowed 查询是否允许未分组 Key 调度
func (s *SettingService) IsUngroupedKeySchedulingAllowed(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyAllowUngroupedKeyScheduling)
if err != nil {
return false // fail-closed: 查询失败时默认不允许
}
return value == "true"
}
// GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求 // GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求
// 使用进程内 atomic.Value 缓存60 秒 TTL热路径零锁开销 // 使用进程内 atomic.Value 缓存60 秒 TTL热路径零锁开销
// singleflight 防止缓存过期时 thundering herd // singleflight 防止缓存过期时 thundering herd

View File

@@ -65,6 +65,9 @@ type SystemSettings struct {
// Claude Code version check // Claude Code version check
MinClaudeCodeVersion string MinClaudeCodeVersion string
// 分组隔离:允许未分组 Key 调度(默认 false → 403
AllowUngroupedKeyScheduling bool
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {

View File

@@ -18,7 +18,8 @@ type TokenRefreshService struct {
refreshers []TokenRefresher refreshers []TokenRefresher
cfg *config.TokenRefreshConfig cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator cacheInvalidator TokenCacheInvalidator
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
stopCh chan struct{} stopCh chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
@@ -34,12 +35,14 @@ func NewTokenRefreshService(
cacheInvalidator TokenCacheInvalidator, cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache, schedulerCache SchedulerCache,
cfg *config.Config, cfg *config.Config,
tempUnschedCache TempUnschedCache,
) *TokenRefreshService { ) *TokenRefreshService {
s := &TokenRefreshService{ s := &TokenRefreshService{
accountRepo: accountRepo, accountRepo: accountRepo,
cfg: &cfg.TokenRefresh, cfg: &cfg.TokenRefresh,
cacheInvalidator: cacheInvalidator, cacheInvalidator: cacheInvalidator,
schedulerCache: schedulerCache, schedulerCache: schedulerCache,
tempUnschedCache: tempUnschedCache,
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
} }
@@ -231,6 +234,26 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID) slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
} }
} }
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
"account_id", account.ID,
"error", clearErr,
)
} else {
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
}
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
if s.tempUnschedCache != nil {
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
"account_id", account.ID,
"error", clearErr,
)
}
}
}
// 对所有 OAuth 账号调用缓存失效InvalidateToken 内部根据平台判断是否需要处理) // 对所有 OAuth 账号调用缓存失效InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil { if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
@@ -257,8 +280,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
return nil return nil
} }
// Antigravity 账户:不可重试错误直接标记 error 状态并返回 // 不可重试错误invalid_grant/invalid_client 等)直接标记 error 状态并返回
if account.Platform == PlatformAntigravity && isNonRetryableRefreshError(err) { if isNonRetryableRefreshError(err) {
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err) errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil { if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
slog.Error("token_refresh.set_error_status_failed", slog.Error("token_refresh.set_error_status_failed",
@@ -285,23 +308,13 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
} }
} }
// Antigravity 账户:其他错误仅记录日志,不标记 error可能是临时网络问题 // 可重试错误耗尽:仅记录日志,不标记 error可能是临时网络问题,下个周期继续重试
// 其他平台账户:重试失败后标记 error slog.Warn("token_refresh.retry_exhausted",
if account.Platform == PlatformAntigravity { "account_id", account.ID,
slog.Warn("token_refresh.retry_exhausted_antigravity", "platform", account.Platform,
"account_id", account.ID, "max_retries", s.cfg.MaxRetries,
"max_retries", s.cfg.MaxRetries, "error", lastErr,
"error", lastErr, )
)
} else {
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
slog.Error("token_refresh.set_error_status_failed",
"account_id", account.ID,
"error", err,
)
}
}
return lastErr return lastErr
} }

View File

@@ -14,10 +14,11 @@ import (
type tokenRefreshAccountRepo struct { type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini mockAccountRepoForGemini
updateCalls int updateCalls int
setErrorCalls int setErrorCalls int
lastAccount *Account clearTempCalls int
updateErr error lastAccount *Account
updateErr error
} }
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
@@ -31,6 +32,11 @@ func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorM
return nil return nil
} }
func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error {
r.clearTempCalls++
return nil
}
type tokenCacheInvalidatorStub struct { type tokenCacheInvalidatorStub struct {
calls int calls int
err error err error
@@ -41,6 +47,23 @@ func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account
return s.err return s.err
} }
type tempUnschedCacheStub struct {
deleteCalls int
}
func (s *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
return nil
}
func (s *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) {
return nil, nil
}
func (s *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error {
s.deleteCalls++
return nil
}
type tokenRefresherStub struct { type tokenRefresherStub struct {
credentials map[string]any credentials map[string]any
err error err error
@@ -70,7 +93,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 5, ID: 5,
Platform: PlatformGemini, Platform: PlatformGemini,
@@ -98,7 +121,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 6, ID: 6,
Platform: PlatformGemini, Platform: PlatformGemini,
@@ -124,7 +147,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 7, ID: 7,
Platform: PlatformGemini, Platform: PlatformGemini,
@@ -151,7 +174,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 8, ID: 8,
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
@@ -179,7 +202,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 9, ID: 9,
Platform: PlatformGemini, Platform: PlatformGemini,
@@ -207,7 +230,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 10, ID: 10,
Platform: PlatformOpenAI, // OpenAI OAuth 账户 Platform: PlatformOpenAI, // OpenAI OAuth 账户
@@ -235,7 +258,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 11, ID: 11,
Platform: PlatformGemini, Platform: PlatformGemini,
@@ -254,7 +277,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效 require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效
} }
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况 // TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试可重试错误耗尽不标记 error
func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{} repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{} invalidator := &tokenCacheInvalidatorStub{}
@@ -264,7 +287,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 12, ID: 12,
Platform: PlatformGemini, Platform: PlatformGemini,
@@ -278,7 +301,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新 require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效 require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态 require.Equal(t, 0, repo.setErrorCalls) // 可重试错误耗尽不标记 error下个周期继续重试
} }
// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态 // TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
@@ -291,7 +314,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 13, ID: 13,
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
@@ -318,7 +341,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
RetryBackoffSeconds: 0, RetryBackoffSeconds: 0,
}, },
} }
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{ account := &Account{
ID: 14, ID: 14,
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
@@ -335,6 +358,77 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态 require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态
} }
// TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable 测试刷新成功后清除临时不可调度DB + Redis
func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
tempCache := &tempUnschedCacheStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
until := time.Now().Add(10 * time.Minute)
account := &Account{
ID: 15,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
TempUnschedulableUntil: &until,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
}
// TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms 测试所有平台不可重试错误都 SetError
func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "anthropic", platform: PlatformAnthropic},
{name: "openai", platform: PlatformOpenAI},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 3,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 16,
Platform: tt.platform,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("invalid_grant: token revoked"),
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError
})
}
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断 // TestIsNonRetryableRefreshError 测试不可重试错误判断
func TestIsNonRetryableRefreshError(t *testing.T) { func TestIsNonRetryableRefreshError(t *testing.T) {
tests := []struct { tests := []struct {

View File

@@ -315,6 +315,15 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return stats, nil return stats, nil
} }
// GetAPIKeyModelStats returns per-model usage stats for a specific API Key.
func (s *UsageService) GetAPIKeyModelStats(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, 0, apiKeyID, 0, 0, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get api key model stats: %w", err)
}
return stats, nil
}
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys. // GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)

View File

@@ -96,6 +96,18 @@ func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64
func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error { func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error {
return nil return nil
} }
func (m *mockBillingCache) GetAPIKeyRateLimit(context.Context, int64) (*APIKeyRateLimitCacheData, error) {
return nil, nil
}
func (m *mockBillingCache) SetAPIKeyRateLimit(context.Context, int64, *APIKeyRateLimitCacheData) error {
return nil
}
func (m *mockBillingCache) UpdateAPIKeyRateLimitUsage(context.Context, int64, float64) error {
return nil
}
func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) error {
return nil
}
// --- 测试 --- // --- 测试 ---

View File

@@ -48,8 +48,9 @@ func ProvideTokenRefreshService(
cacheInvalidator TokenCacheInvalidator, cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache, schedulerCache SchedulerCache,
cfg *config.Config, cfg *config.Config,
tempUnschedCache TempUnschedCache,
) *TokenRefreshService { ) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg) svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
svc.SetSoraAccountRepo(soraAccountRepo) svc.SetSoraAccountRepo(soraAccountRepo)
svc.Start() svc.Start()

View File

@@ -19,11 +19,47 @@ $$;
CREATE INDEX IF NOT EXISTS idx_usage_logs_request_type_created_at CREATE INDEX IF NOT EXISTS idx_usage_logs_request_type_created_at
ON usage_logs (request_type, created_at); ON usage_logs (request_type, created_at);
-- Backfill from legacy fields. openai_ws_mode has higher priority than stream. -- Backfill from legacy fields in bounded batches.
UPDATE usage_logs -- Why bounded:
SET request_type = CASE -- 1) Full-table UPDATE on large usage_logs can block startup for a long time.
WHEN openai_ws_mode = TRUE THEN 3 -- 2) request_type=0 rows remain query-compatible via legacy fallback logic
WHEN stream = TRUE THEN 2 -- (stream/openai_ws_mode) in repository filters.
ELSE 1 -- 3) Subsequent writes will use explicit request_type and gradually dilute
-- historical unknown rows.
--
-- openai_ws_mode has higher priority than stream.
DO $$
DECLARE
v_rows INTEGER := 0;
v_total_rows INTEGER := 0;
v_batch_size INTEGER := 5000;
v_started_at TIMESTAMPTZ := clock_timestamp();
v_max_duration INTERVAL := INTERVAL '8 seconds';
BEGIN
LOOP
WITH batch AS (
SELECT id
FROM usage_logs
WHERE request_type = 0
ORDER BY id
LIMIT v_batch_size
)
UPDATE usage_logs ul
SET request_type = CASE
WHEN ul.openai_ws_mode = TRUE THEN 3
WHEN ul.stream = TRUE THEN 2
ELSE 1
END
FROM batch
WHERE ul.id = batch.id;
GET DIAGNOSTICS v_rows = ROW_COUNT;
EXIT WHEN v_rows = 0;
v_total_rows := v_total_rows + v_rows;
EXIT WHEN clock_timestamp() - v_started_at >= v_max_duration;
END LOOP;
RAISE NOTICE 'usage_logs.request_type startup backfill rows=%', v_total_rows;
END END
WHERE request_type = 0; $$;

View File

@@ -0,0 +1,15 @@
-- Add rate limit fields to api_keys table
-- Rate limit configuration (0 = unlimited)
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS rate_limit_5h decimal(20,8) NOT NULL DEFAULT 0;
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS rate_limit_1d decimal(20,8) NOT NULL DEFAULT 0;
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS rate_limit_7d decimal(20,8) NOT NULL DEFAULT 0;
-- Rate limit usage tracking
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS usage_5h decimal(20,8) NOT NULL DEFAULT 0;
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS usage_1d decimal(20,8) NOT NULL DEFAULT 0;
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS usage_7d decimal(20,8) NOT NULL DEFAULT 0;
-- Window start times (nullable)
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS window_5h_start timestamptz;
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS window_1d_start timestamptz;
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS window_7d_start timestamptz;

View File

@@ -78,6 +78,9 @@ export interface SystemSettings {
// Claude Code version check // Claude Code version check
min_claude_code_version: string min_claude_code_version: string
// 分组隔离
allow_ungrouped_key_scheduling: boolean
} }
export interface UpdateSettingsRequest { export interface UpdateSettingsRequest {
@@ -128,6 +131,7 @@ export interface UpdateSettingsRequest {
ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string
ops_metrics_interval_seconds?: number ops_metrics_interval_seconds?: number
min_claude_code_version?: string min_claude_code_version?: string
allow_ungrouped_key_scheduling?: boolean
} }
/** /**

View File

@@ -46,6 +46,7 @@ export async function getById(id: number): Promise<ApiKey> {
* @param ipBlacklist - Optional IP blacklist * @param ipBlacklist - Optional IP blacklist
* @param quota - Optional quota limit in USD (0 = unlimited) * @param quota - Optional quota limit in USD (0 = unlimited)
* @param expiresInDays - Optional days until expiry (undefined = never expires) * @param expiresInDays - Optional days until expiry (undefined = never expires)
* @param rateLimitData - Optional rate limit fields
* @returns Created API key * @returns Created API key
*/ */
export async function create( export async function create(
@@ -55,7 +56,8 @@ export async function create(
ipWhitelist?: string[], ipWhitelist?: string[],
ipBlacklist?: string[], ipBlacklist?: string[],
quota?: number, quota?: number,
expiresInDays?: number expiresInDays?: number,
rateLimitData?: { rate_limit_5h?: number; rate_limit_1d?: number; rate_limit_7d?: number }
): Promise<ApiKey> { ): Promise<ApiKey> {
const payload: CreateApiKeyRequest = { name } const payload: CreateApiKeyRequest = { name }
if (groupId !== undefined) { if (groupId !== undefined) {
@@ -76,6 +78,15 @@ export async function create(
if (expiresInDays !== undefined && expiresInDays > 0) { if (expiresInDays !== undefined && expiresInDays > 0) {
payload.expires_in_days = expiresInDays payload.expires_in_days = expiresInDays
} }
if (rateLimitData?.rate_limit_5h && rateLimitData.rate_limit_5h > 0) {
payload.rate_limit_5h = rateLimitData.rate_limit_5h
}
if (rateLimitData?.rate_limit_1d && rateLimitData.rate_limit_1d > 0) {
payload.rate_limit_1d = rateLimitData.rate_limit_1d
}
if (rateLimitData?.rate_limit_7d && rateLimitData.rate_limit_7d > 0) {
payload.rate_limit_7d = rateLimitData.rate_limit_7d
}
const { data } = await apiClient.post<ApiKey>('/keys', payload) const { data } = await apiClient.post<ApiKey>('/keys', payload)
return data return data

View File

@@ -560,6 +560,19 @@ export default {
resetQuotaConfirmMessage: 'Are you sure you want to reset the used quota (${used}) for key "{name}" to 0? This action cannot be undone.', resetQuotaConfirmMessage: 'Are you sure you want to reset the used quota (${used}) for key "{name}" to 0? This action cannot be undone.',
quotaResetSuccess: 'Quota reset successfully', quotaResetSuccess: 'Quota reset successfully',
failedToResetQuota: 'Failed to reset quota', failedToResetQuota: 'Failed to reset quota',
rateLimitColumn: 'Rate Limit',
rateLimitSection: 'Rate Limit',
resetUsage: 'Reset',
rateLimit5h: '5-Hour Limit (USD)',
rateLimit1d: 'Daily Limit (USD)',
rateLimit7d: '7-Day Limit (USD)',
rateLimitHint: 'Set the maximum spending for this key within each time window. 0 = unlimited.',
rateLimitUsage: 'Rate Limit Usage',
resetRateLimitUsage: 'Reset Rate Limit Usage',
resetRateLimitTitle: 'Confirm Reset Rate Limit',
resetRateLimitConfirmMessage: 'Are you sure you want to reset the rate limit usage for key "{name}"? All time window usage will be reset to zero. This action cannot be undone.',
rateLimitResetSuccess: 'Rate limit usage reset successfully',
failedToResetRateLimit: 'Failed to reset rate limit usage',
expiration: 'Expiration', expiration: 'Expiration',
expiresInDays: '{days} days', expiresInDays: '{days} days',
extendDays: '+{days} days', extendDays: '+{days} days',
@@ -3578,6 +3591,12 @@ export default {
minVersionHint: minVersionHint:
'Reject Claude Code clients below this version (semver format). Leave empty to disable version check.' 'Reject Claude Code clients below this version (semver format). Leave empty to disable version check.'
}, },
scheduling: {
title: 'Gateway Scheduling Settings',
description: 'Control API Key scheduling behavior',
allowUngroupedKey: 'Allow Ungrouped Key Scheduling',
allowUngroupedKeyHint: 'When disabled, API Keys not assigned to any group cannot make requests (403 Forbidden). Keep disabled to ensure all Keys belong to a specific group.'
},
site: { site: {
title: 'Site Settings', title: 'Site Settings',
description: 'Customize site branding', description: 'Customize site branding',

View File

@@ -566,6 +566,19 @@ export default {
resetQuotaConfirmMessage: '确定要将密钥 "{name}" 的已用额度(${used})重置为 0 吗?此操作不可撤销。', resetQuotaConfirmMessage: '确定要将密钥 "{name}" 的已用额度(${used})重置为 0 吗?此操作不可撤销。',
quotaResetSuccess: '额度重置成功', quotaResetSuccess: '额度重置成功',
failedToResetQuota: '重置额度失败', failedToResetQuota: '重置额度失败',
rateLimitColumn: '速率限制',
rateLimitSection: '速率限制',
resetUsage: '重置',
rateLimit5h: '5小时限额 (USD)',
rateLimit1d: '日限额 (USD)',
rateLimit7d: '7天限额 (USD)',
rateLimitHint: '设置此密钥在指定时间窗口内的最大消费额。0 = 无限制。',
rateLimitUsage: '速率限制用量',
resetRateLimitUsage: '重置速率限制用量',
resetRateLimitTitle: '确认重置速率限制',
resetRateLimitConfirmMessage: '确定要重置密钥 "{name}" 的速率限制用量吗?所有时间窗口的已用额度将归零。此操作不可撤销。',
rateLimitResetSuccess: '速率限制已重置',
failedToResetRateLimit: '重置速率限制失败',
expiration: '密钥有效期', expiration: '密钥有效期',
expiresInDays: '{days} 天', expiresInDays: '{days} 天',
extendDays: '+{days} 天', extendDays: '+{days} 天',
@@ -3746,6 +3759,12 @@ export default {
minVersionPlaceholder: '例如 2.1.63', minVersionPlaceholder: '例如 2.1.63',
minVersionHint: '拒绝低于此版本的 Claude Code 客户端请求semver 格式)。留空则不检查版本。' minVersionHint: '拒绝低于此版本的 Claude Code 客户端请求semver 格式)。留空则不检查版本。'
}, },
scheduling: {
title: '网关调度设置',
description: '控制 API Key 的调度行为',
allowUngroupedKey: '允许未分组 Key 调度',
allowUngroupedKeyHint: '关闭后,未分配到任何分组的 API Key 将无法发起请求(返回 403。建议保持关闭以确保所有 Key 都归属明确的分组。'
},
site: { site: {
title: '站点设置', title: '站点设置',
description: '自定义站点品牌', description: '自定义站点品牌',

View File

@@ -421,6 +421,15 @@ export interface ApiKey {
created_at: string created_at: string
updated_at: string updated_at: string
group?: Group group?: Group
rate_limit_5h: number
rate_limit_1d: number
rate_limit_7d: number
usage_5h: number
usage_1d: number
usage_7d: number
window_5h_start: string | null
window_1d_start: string | null
window_7d_start: string | null
} }
export interface CreateApiKeyRequest { export interface CreateApiKeyRequest {
@@ -431,6 +440,9 @@ export interface CreateApiKeyRequest {
ip_blacklist?: string[] ip_blacklist?: string[]
quota?: number // Quota limit in USD (0 = unlimited) quota?: number // Quota limit in USD (0 = unlimited)
expires_in_days?: number // Days until expiry (null = never expires) expires_in_days?: number // Days until expiry (null = never expires)
rate_limit_5h?: number
rate_limit_1d?: number
rate_limit_7d?: number
} }
export interface UpdateApiKeyRequest { export interface UpdateApiKeyRequest {
@@ -442,6 +454,10 @@ export interface UpdateApiKeyRequest {
quota?: number // Quota limit in USD (null = no change, 0 = unlimited) quota?: number // Quota limit in USD (null = no change, 0 = unlimited)
expires_at?: string | null // Expiration time (null = no change) expires_at?: string | null // Expiration time (null = no change)
reset_quota?: boolean // Reset quota_used to 0 reset_quota?: boolean // Reset quota_used to 0
rate_limit_5h?: number
rate_limit_1d?: number
rate_limit_7d?: number
reset_rate_limit_usage?: boolean
} }
export interface CreateGroupRequest { export interface CreateGroupRequest {

View File

@@ -246,7 +246,10 @@
{{ t('admin.dashboard.recentUsage') }} (Top 12) {{ t('admin.dashboard.recentUsage') }} (Top 12)
</h3> </h3>
<div class="h-64"> <div class="h-64">
<Line v-if="userTrendChartData" :data="userTrendChartData" :options="lineOptions" /> <div v-if="userTrendLoading" class="flex h-full items-center justify-center">
<LoadingSpinner size="md" />
</div>
<Line v-else-if="userTrendChartData" :data="userTrendChartData" :options="lineOptions" />
<div <div
v-else v-else
class="flex h-full items-center justify-center text-sm text-gray-500 dark:text-gray-400" class="flex h-full items-center justify-center text-sm text-gray-500 dark:text-gray-400"
@@ -306,11 +309,13 @@ const appStore = useAppStore()
const stats = ref<DashboardStats | null>(null) const stats = ref<DashboardStats | null>(null)
const loading = ref(false) const loading = ref(false)
const chartsLoading = ref(false) const chartsLoading = ref(false)
const userTrendLoading = ref(false)
// Chart data // Chart data
const trendData = ref<TrendDataPoint[]>([]) const trendData = ref<TrendDataPoint[]>([])
const modelStats = ref<ModelStat[]>([]) const modelStats = ref<ModelStat[]>([])
const userTrend = ref<UserUsageTrendPoint[]>([]) const userTrend = ref<UserUsageTrendPoint[]>([])
let chartLoadSeq = 0
// Helper function to format date in local timezone // Helper function to format date in local timezone
const formatLocalDate = (date: Date): string => { const formatLocalDate = (date: Date): string => {
@@ -531,7 +536,9 @@ const loadDashboardStats = async () => {
} }
const loadChartData = async () => { const loadChartData = async () => {
const currentSeq = ++chartLoadSeq
chartsLoading.value = true chartsLoading.value = true
userTrendLoading.value = true
try { try {
const params = { const params = {
start_date: startDate.value, start_date: startDate.value,
@@ -539,20 +546,39 @@ const loadChartData = async () => {
granularity: granularity.value granularity: granularity.value
} }
const [trendResponse, modelResponse, userResponse] = await Promise.all([ const [trendResponse, modelResponse] = await Promise.all([
adminAPI.dashboard.getUsageTrend(params), adminAPI.dashboard.getUsageTrend(params),
adminAPI.dashboard.getModelStats({ start_date: startDate.value, end_date: endDate.value }), adminAPI.dashboard.getModelStats({ start_date: startDate.value, end_date: endDate.value })
adminAPI.dashboard.getUserUsageTrend({ ...params, limit: 12 })
]) ])
if (currentSeq !== chartLoadSeq) return
trendData.value = trendResponse.trend || [] trendData.value = trendResponse.trend || []
modelStats.value = modelResponse.models || [] modelStats.value = modelResponse.models || []
userTrend.value = userResponse.trend || []
} catch (error) { } catch (error) {
if (currentSeq !== chartLoadSeq) return
console.error('Error loading chart data:', error) console.error('Error loading chart data:', error)
} finally { } finally {
if (currentSeq !== chartLoadSeq) return
chartsLoading.value = false chartsLoading.value = false
} }
try {
const params = {
start_date: startDate.value,
end_date: endDate.value,
granularity: granularity.value,
limit: 12
}
const userResponse = await adminAPI.dashboard.getUserUsageTrend(params)
if (currentSeq !== chartLoadSeq) return
userTrend.value = userResponse.trend || []
} catch (error) {
if (currentSeq !== chartLoadSeq) return
console.error('Error loading user trend:', error)
} finally {
if (currentSeq !== chartLoadSeq) return
userTrendLoading.value = false
}
} }
onMounted(() => { onMounted(() => {

View File

@@ -737,6 +737,34 @@
</div> </div>
</div> </div>
<!-- Gateway Scheduling Settings -->
<div class="card">
<div class="border-b border-gray-100 px-6 py-4 dark:border-dark-700">
<h2 class="text-lg font-semibold text-gray-900 dark:text-white">
{{ t('admin.settings.scheduling.title') }}
</h2>
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.settings.scheduling.description') }}
</p>
</div>
<div class="p-6">
<div class="flex items-center justify-between">
<div>
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.settings.scheduling.allowUngroupedKey') }}
</label>
<p class="mt-0.5 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.settings.scheduling.allowUngroupedKeyHint') }}
</p>
</div>
<label class="toggle">
<input v-model="form.allow_ungrouped_key_scheduling" type="checkbox" />
<span class="toggle-slider"></span>
</label>
</div>
</div>
</div>
<!-- Site Settings --> <!-- Site Settings -->
<div class="card"> <div class="card">
<div class="border-b border-gray-100 px-6 py-4 dark:border-dark-700"> <div class="border-b border-gray-100 px-6 py-4 dark:border-dark-700">
@@ -1438,7 +1466,9 @@ const form = reactive<SettingsForm>({
ops_query_mode_default: 'auto', ops_query_mode_default: 'auto',
ops_metrics_interval_seconds: 60, ops_metrics_interval_seconds: 60,
// Claude Code version check // Claude Code version check
min_claude_code_version: '' min_claude_code_version: '',
// 分组隔离
allow_ungrouped_key_scheduling: false
}) })
const defaultSubscriptionGroupOptions = computed<DefaultSubscriptionGroupOption[]>(() => const defaultSubscriptionGroupOptions = computed<DefaultSubscriptionGroupOption[]>(() =>
@@ -1623,7 +1653,8 @@ async function saveSettings() {
fallback_model_antigravity: form.fallback_model_antigravity, fallback_model_antigravity: form.fallback_model_antigravity,
enable_identity_patch: form.enable_identity_patch, enable_identity_patch: form.enable_identity_patch,
identity_patch_prompt: form.identity_patch_prompt, identity_patch_prompt: form.identity_patch_prompt,
min_claude_code_version: form.min_claude_code_version min_claude_code_version: form.min_claude_code_version,
allow_ungrouped_key_scheduling: form.allow_ungrouped_key_scheduling
} }
const updated = await adminAPI.settings.updateSettings(payload) const updated = await adminAPI.settings.updateSettings(payload)
Object.assign(form, updated) Object.assign(form, updated)

View File

@@ -137,6 +137,97 @@
</div> </div>
</template> </template>
<template #cell-rate_limit="{ row }">
<div v-if="row.rate_limit_5h > 0 || row.rate_limit_1d > 0 || row.rate_limit_7d > 0" class="space-y-1.5 min-w-[140px]">
<!-- 5h window -->
<div v-if="row.rate_limit_5h > 0">
<div class="flex items-center justify-between text-xs">
<span class="text-gray-500 dark:text-gray-400">5h</span>
<span :class="[
'font-medium tabular-nums',
row.usage_5h >= row.rate_limit_5h ? 'text-red-500' :
row.usage_5h >= row.rate_limit_5h * 0.8 ? 'text-yellow-500' :
'text-gray-700 dark:text-gray-300'
]">
${{ row.usage_5h?.toFixed(2) || '0.00' }}/${{ row.rate_limit_5h?.toFixed(2) }}
</span>
</div>
<div class="h-1 w-full overflow-hidden rounded-full bg-gray-200 dark:bg-dark-600">
<div
:class="[
'h-full rounded-full transition-all',
row.usage_5h >= row.rate_limit_5h ? 'bg-red-500' :
row.usage_5h >= row.rate_limit_5h * 0.8 ? 'bg-yellow-500' :
'bg-emerald-500'
]"
:style="{ width: Math.min((row.usage_5h / row.rate_limit_5h) * 100, 100) + '%' }"
/>
</div>
</div>
<!-- 1d window -->
<div v-if="row.rate_limit_1d > 0">
<div class="flex items-center justify-between text-xs">
<span class="text-gray-500 dark:text-gray-400">1d</span>
<span :class="[
'font-medium tabular-nums',
row.usage_1d >= row.rate_limit_1d ? 'text-red-500' :
row.usage_1d >= row.rate_limit_1d * 0.8 ? 'text-yellow-500' :
'text-gray-700 dark:text-gray-300'
]">
${{ row.usage_1d?.toFixed(2) || '0.00' }}/${{ row.rate_limit_1d?.toFixed(2) }}
</span>
</div>
<div class="h-1 w-full overflow-hidden rounded-full bg-gray-200 dark:bg-dark-600">
<div
:class="[
'h-full rounded-full transition-all',
row.usage_1d >= row.rate_limit_1d ? 'bg-red-500' :
row.usage_1d >= row.rate_limit_1d * 0.8 ? 'bg-yellow-500' :
'bg-emerald-500'
]"
:style="{ width: Math.min((row.usage_1d / row.rate_limit_1d) * 100, 100) + '%' }"
/>
</div>
</div>
<!-- 7d window -->
<div v-if="row.rate_limit_7d > 0">
<div class="flex items-center justify-between text-xs">
<span class="text-gray-500 dark:text-gray-400">7d</span>
<span :class="[
'font-medium tabular-nums',
row.usage_7d >= row.rate_limit_7d ? 'text-red-500' :
row.usage_7d >= row.rate_limit_7d * 0.8 ? 'text-yellow-500' :
'text-gray-700 dark:text-gray-300'
]">
${{ row.usage_7d?.toFixed(2) || '0.00' }}/${{ row.rate_limit_7d?.toFixed(2) }}
</span>
</div>
<div class="h-1 w-full overflow-hidden rounded-full bg-gray-200 dark:bg-dark-600">
<div
:class="[
'h-full rounded-full transition-all',
row.usage_7d >= row.rate_limit_7d ? 'bg-red-500' :
row.usage_7d >= row.rate_limit_7d * 0.8 ? 'bg-yellow-500' :
'bg-emerald-500'
]"
:style="{ width: Math.min((row.usage_7d / row.rate_limit_7d) * 100, 100) + '%' }"
/>
</div>
</div>
<!-- Reset button -->
<button
v-if="row.usage_5h > 0 || row.usage_1d > 0 || row.usage_7d > 0"
@click.stop="confirmResetRateLimitFromTable(row)"
class="mt-0.5 inline-flex items-center gap-1 rounded px-1.5 py-0.5 text-xs text-gray-500 transition-colors hover:bg-gray-100 hover:text-primary-600 dark:hover:bg-dark-700 dark:hover:text-primary-400"
:title="t('keys.resetRateLimitUsage')"
>
<Icon name="refresh" size="xs" />
{{ t('keys.resetUsage') }}
</button>
</div>
<span v-else class="text-sm text-gray-400 dark:text-dark-500">-</span>
</template>
<template #cell-expires_at="{ value }"> <template #cell-expires_at="{ value }">
<span v-if="value" :class="[ <span v-if="value" :class="[
'text-sm', 'text-sm',
@@ -452,6 +543,180 @@
</div> </div>
</div> </div>
<!-- Rate Limit Section -->
<div class="space-y-3">
<div class="flex items-center justify-between">
<label class="input-label mb-0">{{ t('keys.rateLimitSection') }}</label>
<button
type="button"
@click="formData.enable_rate_limit = !formData.enable_rate_limit"
:class="[
'relative inline-flex h-5 w-9 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none',
formData.enable_rate_limit ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-4 w-4 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
formData.enable_rate_limit ? 'translate-x-4' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="formData.enable_rate_limit" class="space-y-4 pt-2">
<p class="input-hint -mt-2">{{ t('keys.rateLimitHint') }}</p>
<!-- 5-Hour Limit -->
<div>
<label class="input-label">{{ t('keys.rateLimit5h') }}</label>
<div class="relative">
<span class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-500">$</span>
<input
v-model.number="formData.rate_limit_5h"
type="number"
step="0.01"
min="0"
class="input pl-7"
:placeholder="'0'"
/>
</div>
<!-- Usage info (edit mode only) -->
<div v-if="showEditModal && selectedKey && selectedKey.rate_limit_5h > 0" class="mt-2">
<div class="flex items-center gap-2">
<div class="flex-1 rounded-lg bg-gray-100 px-3 py-2 dark:bg-dark-700 text-sm">
<span :class="[
'font-medium',
selectedKey.usage_5h >= selectedKey.rate_limit_5h ? 'text-red-500' :
selectedKey.usage_5h >= selectedKey.rate_limit_5h * 0.8 ? 'text-yellow-500' :
'text-gray-900 dark:text-white'
]">
${{ selectedKey.usage_5h?.toFixed(4) || '0.0000' }}
</span>
<span class="mx-2 text-gray-400">/</span>
<span class="text-gray-500 dark:text-gray-400">
${{ selectedKey.rate_limit_5h?.toFixed(2) || '0.00' }}
</span>
</div>
</div>
<div class="mt-1 h-1.5 w-full overflow-hidden rounded-full bg-gray-200 dark:bg-dark-600">
<div
:class="[
'h-full rounded-full transition-all',
selectedKey.usage_5h >= selectedKey.rate_limit_5h ? 'bg-red-500' :
selectedKey.usage_5h >= selectedKey.rate_limit_5h * 0.8 ? 'bg-yellow-500' :
'bg-green-500'
]"
:style="{ width: Math.min((selectedKey.usage_5h / selectedKey.rate_limit_5h) * 100, 100) + '%' }"
/>
</div>
</div>
</div>
<!-- Daily Limit -->
<div>
<label class="input-label">{{ t('keys.rateLimit1d') }}</label>
<div class="relative">
<span class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-500">$</span>
<input
v-model.number="formData.rate_limit_1d"
type="number"
step="0.01"
min="0"
class="input pl-7"
:placeholder="'0'"
/>
</div>
<!-- Usage info (edit mode only) -->
<div v-if="showEditModal && selectedKey && selectedKey.rate_limit_1d > 0" class="mt-2">
<div class="flex items-center gap-2">
<div class="flex-1 rounded-lg bg-gray-100 px-3 py-2 dark:bg-dark-700 text-sm">
<span :class="[
'font-medium',
selectedKey.usage_1d >= selectedKey.rate_limit_1d ? 'text-red-500' :
selectedKey.usage_1d >= selectedKey.rate_limit_1d * 0.8 ? 'text-yellow-500' :
'text-gray-900 dark:text-white'
]">
${{ selectedKey.usage_1d?.toFixed(4) || '0.0000' }}
</span>
<span class="mx-2 text-gray-400">/</span>
<span class="text-gray-500 dark:text-gray-400">
${{ selectedKey.rate_limit_1d?.toFixed(2) || '0.00' }}
</span>
</div>
</div>
<div class="mt-1 h-1.5 w-full overflow-hidden rounded-full bg-gray-200 dark:bg-dark-600">
<div
:class="[
'h-full rounded-full transition-all',
selectedKey.usage_1d >= selectedKey.rate_limit_1d ? 'bg-red-500' :
selectedKey.usage_1d >= selectedKey.rate_limit_1d * 0.8 ? 'bg-yellow-500' :
'bg-green-500'
]"
:style="{ width: Math.min((selectedKey.usage_1d / selectedKey.rate_limit_1d) * 100, 100) + '%' }"
/>
</div>
</div>
</div>
<!-- 7-Day Limit -->
<div>
<label class="input-label">{{ t('keys.rateLimit7d') }}</label>
<div class="relative">
<span class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-500">$</span>
<input
v-model.number="formData.rate_limit_7d"
type="number"
step="0.01"
min="0"
class="input pl-7"
:placeholder="'0'"
/>
</div>
<!-- Usage info (edit mode only) -->
<div v-if="showEditModal && selectedKey && selectedKey.rate_limit_7d > 0" class="mt-2">
<div class="flex items-center gap-2">
<div class="flex-1 rounded-lg bg-gray-100 px-3 py-2 dark:bg-dark-700 text-sm">
<span :class="[
'font-medium',
selectedKey.usage_7d >= selectedKey.rate_limit_7d ? 'text-red-500' :
selectedKey.usage_7d >= selectedKey.rate_limit_7d * 0.8 ? 'text-yellow-500' :
'text-gray-900 dark:text-white'
]">
${{ selectedKey.usage_7d?.toFixed(4) || '0.0000' }}
</span>
<span class="mx-2 text-gray-400">/</span>
<span class="text-gray-500 dark:text-gray-400">
${{ selectedKey.rate_limit_7d?.toFixed(2) || '0.00' }}
</span>
</div>
</div>
<div class="mt-1 h-1.5 w-full overflow-hidden rounded-full bg-gray-200 dark:bg-dark-600">
<div
:class="[
'h-full rounded-full transition-all',
selectedKey.usage_7d >= selectedKey.rate_limit_7d ? 'bg-red-500' :
selectedKey.usage_7d >= selectedKey.rate_limit_7d * 0.8 ? 'bg-yellow-500' :
'bg-green-500'
]"
:style="{ width: Math.min((selectedKey.usage_7d / selectedKey.rate_limit_7d) * 100, 100) + '%' }"
/>
</div>
</div>
</div>
<!-- Reset Rate Limit button (edit mode only) -->
<div v-if="showEditModal && selectedKey && (selectedKey.rate_limit_5h > 0 || selectedKey.rate_limit_1d > 0 || selectedKey.rate_limit_7d > 0)">
<button
type="button"
@click="confirmResetRateLimit"
class="btn btn-secondary text-sm"
>
{{ t('keys.resetRateLimitUsage') }}
</button>
</div>
</div>
</div>
<!-- Expiration Section --> <!-- Expiration Section -->
<div class="space-y-3"> <div class="space-y-3">
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
@@ -593,6 +858,18 @@
@cancel="showResetQuotaDialog = false" @cancel="showResetQuotaDialog = false"
/> />
<!-- Reset Rate Limit Confirmation Dialog -->
<ConfirmDialog
:show="showResetRateLimitDialog"
:title="t('keys.resetRateLimitTitle')"
:message="t('keys.resetRateLimitConfirmMessage', { name: selectedKey?.name })"
:confirm-text="t('keys.reset')"
:cancel-text="t('common.cancel')"
:danger="true"
@confirm="resetRateLimitUsage"
@cancel="showResetRateLimitDialog = false"
/>
<!-- Use Key Modal --> <!-- Use Key Modal -->
<UseKeyModal <UseKeyModal
:show="showUseKeyModal" :show="showUseKeyModal"
@@ -743,6 +1020,7 @@ const columns = computed<Column[]>(() => [
{ key: 'key', label: t('keys.apiKey'), sortable: false }, { key: 'key', label: t('keys.apiKey'), sortable: false },
{ key: 'group', label: t('keys.group'), sortable: false }, { key: 'group', label: t('keys.group'), sortable: false },
{ key: 'usage', label: t('keys.usage'), sortable: false }, { key: 'usage', label: t('keys.usage'), sortable: false },
{ key: 'rate_limit', label: t('keys.rateLimitColumn'), sortable: false },
{ key: 'expires_at', label: t('keys.expiresAt'), sortable: true }, { key: 'expires_at', label: t('keys.expiresAt'), sortable: true },
{ key: 'status', label: t('common.status'), sortable: true }, { key: 'status', label: t('common.status'), sortable: true },
{ key: 'last_used_at', label: t('keys.lastUsedAt'), sortable: true }, { key: 'last_used_at', label: t('keys.lastUsedAt'), sortable: true },
@@ -768,6 +1046,7 @@ const showCreateModal = ref(false)
const showEditModal = ref(false) const showEditModal = ref(false)
const showDeleteDialog = ref(false) const showDeleteDialog = ref(false)
const showResetQuotaDialog = ref(false) const showResetQuotaDialog = ref(false)
const showResetRateLimitDialog = ref(false)
const showUseKeyModal = ref(false) const showUseKeyModal = ref(false)
const showCcsClientSelect = ref(false) const showCcsClientSelect = ref(false)
const pendingCcsRow = ref<ApiKey | null>(null) const pendingCcsRow = ref<ApiKey | null>(null)
@@ -806,6 +1085,11 @@ const formData = ref({
// Quota settings (empty = unlimited) // Quota settings (empty = unlimited)
enable_quota: false, enable_quota: false,
quota: null as number | null, quota: null as number | null,
// Rate limit settings
enable_rate_limit: false,
rate_limit_5h: null as number | null,
rate_limit_1d: null as number | null,
rate_limit_7d: null as number | null,
enable_expiration: false, enable_expiration: false,
expiration_preset: '30' as '7' | '30' | '90' | 'custom', expiration_preset: '30' as '7' | '30' | '90' | 'custom',
expiration_date: '' expiration_date: ''
@@ -966,6 +1250,10 @@ const editKey = (key: ApiKey) => {
ip_blacklist: (key.ip_blacklist || []).join('\n'), ip_blacklist: (key.ip_blacklist || []).join('\n'),
enable_quota: key.quota > 0, enable_quota: key.quota > 0,
quota: key.quota > 0 ? key.quota : null, quota: key.quota > 0 ? key.quota : null,
enable_rate_limit: (key.rate_limit_5h > 0) || (key.rate_limit_1d > 0) || (key.rate_limit_7d > 0),
rate_limit_5h: key.rate_limit_5h || null,
rate_limit_1d: key.rate_limit_1d || null,
rate_limit_7d: key.rate_limit_7d || null,
enable_expiration: hasExpiration, enable_expiration: hasExpiration,
expiration_preset: 'custom', expiration_preset: 'custom',
expiration_date: key.expires_at ? formatDateTimeLocal(key.expires_at) : '' expiration_date: key.expires_at ? formatDateTimeLocal(key.expires_at) : ''
@@ -1078,6 +1366,13 @@ const handleSubmit = async () => {
expiresAt = '' expiresAt = ''
} }
// Calculate rate limit values (send 0 when toggle is off)
const rateLimitData = formData.value.enable_rate_limit ? {
rate_limit_5h: formData.value.rate_limit_5h && formData.value.rate_limit_5h > 0 ? formData.value.rate_limit_5h : 0,
rate_limit_1d: formData.value.rate_limit_1d && formData.value.rate_limit_1d > 0 ? formData.value.rate_limit_1d : 0,
rate_limit_7d: formData.value.rate_limit_7d && formData.value.rate_limit_7d > 0 ? formData.value.rate_limit_7d : 0,
} : { rate_limit_5h: 0, rate_limit_1d: 0, rate_limit_7d: 0 }
submitting.value = true submitting.value = true
try { try {
if (showEditModal.value && selectedKey.value) { if (showEditModal.value && selectedKey.value) {
@@ -1088,7 +1383,10 @@ const handleSubmit = async () => {
ip_whitelist: ipWhitelist, ip_whitelist: ipWhitelist,
ip_blacklist: ipBlacklist, ip_blacklist: ipBlacklist,
quota: quota, quota: quota,
expires_at: expiresAt expires_at: expiresAt,
rate_limit_5h: rateLimitData.rate_limit_5h,
rate_limit_1d: rateLimitData.rate_limit_1d,
rate_limit_7d: rateLimitData.rate_limit_7d,
}) })
appStore.showSuccess(t('keys.keyUpdatedSuccess')) appStore.showSuccess(t('keys.keyUpdatedSuccess'))
} else { } else {
@@ -1100,7 +1398,8 @@ const handleSubmit = async () => {
ipWhitelist, ipWhitelist,
ipBlacklist, ipBlacklist,
quota, quota,
expiresInDays expiresInDays,
rateLimitData
) )
appStore.showSuccess(t('keys.keyCreatedSuccess')) appStore.showSuccess(t('keys.keyCreatedSuccess'))
// Only advance tour if active, on submit step, and creation succeeded // Only advance tour if active, on submit step, and creation succeeded
@@ -1154,6 +1453,10 @@ const closeModals = () => {
ip_blacklist: '', ip_blacklist: '',
enable_quota: false, enable_quota: false,
quota: null, quota: null,
enable_rate_limit: false,
rate_limit_5h: null,
rate_limit_1d: null,
rate_limit_7d: null,
enable_expiration: false, enable_expiration: false,
expiration_preset: '30', expiration_preset: '30',
expiration_date: '' expiration_date: ''
@@ -1190,6 +1493,37 @@ const resetQuotaUsed = async () => {
} }
} }
// Show reset rate limit confirmation dialog (from edit modal)
const confirmResetRateLimit = () => {
showResetRateLimitDialog.value = true
}
// Show reset rate limit confirmation dialog (from table row)
const confirmResetRateLimitFromTable = (row: ApiKey) => {
selectedKey.value = row
showResetRateLimitDialog.value = true
}
// Reset rate limit usage for an API key
const resetRateLimitUsage = async () => {
if (!selectedKey.value) return
showResetRateLimitDialog.value = false
try {
await keysAPI.update(selectedKey.value.id, { reset_rate_limit_usage: true })
appStore.showSuccess(t('keys.rateLimitResetSuccess'))
// Refresh key data
await loadApiKeys()
// Update the editing key with fresh data
const refreshedKey = apiKeys.value.find(k => k.id === selectedKey.value!.id)
if (refreshedKey) {
selectedKey.value = refreshedKey
}
} catch (error: any) {
const errorMsg = error.response?.data?.detail || t('keys.failedToResetRateLimit')
appStore.showError(errorMsg)
}
}
const importToCcswitch = (row: ApiKey) => { const importToCcswitch = (row: ApiKey) => {
const platform = row.group?.platform || 'anthropic' const platform = row.group?.platform || 'anthropic'