mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
Merge remote-tracking branch 'upstream/main' into feat/channel-insights
# Conflicts: # backend/cmd/server/wire_gen.go
This commit is contained in:
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -246,10 +246,10 @@ jobs:
|
|||||||
if [ -n "$DOCKERHUB_USERNAME" ]; then
|
if [ -n "$DOCKERHUB_USERNAME" ]; then
|
||||||
DOCKER_IMAGE="${DOCKERHUB_USERNAME}/sub2api"
|
DOCKER_IMAGE="${DOCKERHUB_USERNAME}/sub2api"
|
||||||
MESSAGE+="# Docker Hub"$'\n'
|
MESSAGE+="# Docker Hub"$'\n'
|
||||||
MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n'
|
MESSAGE+="docker pull ${DOCKER_IMAGE}:${VERSION}"$'\n'
|
||||||
MESSAGE+="# GitHub Container Registry"$'\n'
|
MESSAGE+="# GitHub Container Registry"$'\n'
|
||||||
fi
|
fi
|
||||||
MESSAGE+="docker pull ${GHCR_IMAGE}:${TAG_NAME}"$'\n'
|
MESSAGE+="docker pull ${GHCR_IMAGE}:${VERSION}"$'\n'
|
||||||
MESSAGE+="\`\`\`"$'\n'$'\n'
|
MESSAGE+="\`\`\`"$'\n'$'\n'
|
||||||
MESSAGE+="🔗 *相关链接:*"$'\n'
|
MESSAGE+="🔗 *相关链接:*"$'\n'
|
||||||
MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n'
|
MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n'
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
0.1.115
|
0.1.116
|
||||||
|
|||||||
@@ -61,8 +61,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingCache := repository.NewBillingCache(redisClient)
|
billingCache := repository.NewBillingCache(redisClient)
|
||||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
||||||
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
|
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
|
||||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig)
|
userRPMCache := repository.NewUserRPMCache(redisClient)
|
||||||
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
|
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
|
||||||
|
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
|
||||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
|
||||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||||
@@ -104,7 +105,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
privacyClientFactory := providePrivacyClientFactory()
|
privacyClientFactory := providePrivacyClientFactory()
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
@@ -124,9 +125,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||||
|
openAI403CounterCache := repository.NewOpenAI403CounterCache(redisClient)
|
||||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
|
||||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
||||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||||
@@ -136,7 +138,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||||
oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
@@ -183,6 +185,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
|
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
||||||
|
registry := payment.ProvideRegistry()
|
||||||
|
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||||
|
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
|
||||||
|
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||||
opsHandler := admin.NewOpsHandler(opsService)
|
opsHandler := admin.NewOpsHandler(opsService)
|
||||||
updateCache := repository.NewUpdateCache(redisClient)
|
updateCache := repository.NewUpdateCache(redisClient)
|
||||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||||
@@ -222,16 +233,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
|
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
|
||||||
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
|
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
|
||||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||||
registry := payment.ProvideRegistry()
|
|
||||||
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
|
||||||
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
|
||||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
|
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
|
||||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
|
||||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||||
availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
|
availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
|
||||||
@@ -261,6 +262,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
|
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
|||||||
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
|
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
|
||||||
pricingSvc := service.NewPricingService(cfg, nil)
|
pricingSvc := service.NewPricingService(cfg, nil)
|
||||||
emailQueueSvc := service.NewEmailQueueService(nil, 1)
|
emailQueueSvc := service.NewEmailQueueService(nil, 1)
|
||||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||||
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
|
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
|
||||||
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
|
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
|
||||||
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
|
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ type Group struct {
|
|||||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||||
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
|
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
|
||||||
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
|
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
|
||||||
|
// 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流
|
||||||
|
RpmLimit int `json:"rpm_limit,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||||
Edges GroupEdges `json:"edges"`
|
Edges GroupEdges `json:"edges"`
|
||||||
@@ -191,7 +193,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
@@ -414,6 +416,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
|
return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case group.FieldRpmLimit:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.RpmLimit = int(value.Int64)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -599,6 +607,9 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("messages_dispatch_model_config=")
|
builder.WriteString("messages_dispatch_model_config=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
|
builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("rpm_limit=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,6 +76,8 @@ const (
|
|||||||
FieldDefaultMappedModel = "default_mapped_model"
|
FieldDefaultMappedModel = "default_mapped_model"
|
||||||
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
|
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
|
||||||
FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
|
FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
|
||||||
|
// FieldRpmLimit holds the string denoting the rpm_limit field in the database.
|
||||||
|
FieldRpmLimit = "rpm_limit"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -181,6 +183,7 @@ var Columns = []string{
|
|||||||
FieldRequirePrivacySet,
|
FieldRequirePrivacySet,
|
||||||
FieldDefaultMappedModel,
|
FieldDefaultMappedModel,
|
||||||
FieldMessagesDispatchModelConfig,
|
FieldMessagesDispatchModelConfig,
|
||||||
|
FieldRpmLimit,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -258,6 +261,8 @@ var (
|
|||||||
DefaultMappedModelValidator func(string) error
|
DefaultMappedModelValidator func(string) error
|
||||||
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
|
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
|
||||||
DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
|
DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
|
||||||
|
// DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
|
||||||
|
DefaultRpmLimit int
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the Group queries.
|
// OrderOption defines the ordering options for the Group queries.
|
||||||
@@ -403,6 +408,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByRpmLimit orders the results by the rpm_limit field.
|
||||||
|
func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldRpmLimit, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -190,6 +190,11 @@ func DefaultMappedModel(v string) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ.
|
||||||
|
func RpmLimit(v int) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -1320,6 +1325,46 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitEQ(v int) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitNEQ(v int) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNEQ(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitIn applies the In predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitIn(vs ...int) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldIn(FieldRpmLimit, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitNotIn(vs ...int) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNotIn(FieldRpmLimit, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitGT applies the GT predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitGT(v int) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGT(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitGTE(v int) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGTE(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitLT applies the LT predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitLT(v int) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLT(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitLTE(v int) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLTE(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.Group {
|
func HasAPIKeys() predicate.Group {
|
||||||
return predicate.Group(func(s *sql.Selector) {
|
return predicate.Group(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -425,6 +425,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate {
|
||||||
|
_c.mutation.SetRpmLimit(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||||
|
func (_c *GroupCreate) SetNillableRpmLimit(v *int) *GroupCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetRpmLimit(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -630,6 +644,10 @@ func (_c *GroupCreate) defaults() error {
|
|||||||
v := group.DefaultMessagesDispatchModelConfig
|
v := group.DefaultMessagesDispatchModelConfig
|
||||||
_c.mutation.SetMessagesDispatchModelConfig(v)
|
_c.mutation.SetMessagesDispatchModelConfig(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||||
|
v := group.DefaultRpmLimit
|
||||||
|
_c.mutation.SetRpmLimit(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -717,6 +735,9 @@ func (_c *GroupCreate) check() error {
|
|||||||
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
|
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
|
||||||
return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
|
return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||||
|
return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -864,6 +885,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
||||||
_node.MessagesDispatchModelConfig = value
|
_node.MessagesDispatchModelConfig = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.RpmLimit(); ok {
|
||||||
|
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
_node.RpmLimit = value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1500,6 +1525,24 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert {
|
||||||
|
u.Set(group.FieldRpmLimit, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateRpmLimit() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldRpmLimit)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||||
|
func (u *GroupUpsert) AddRpmLimit(v int) *GroupUpsert {
|
||||||
|
u.Add(group.FieldRpmLimit, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -2105,6 +2148,27 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetRpmLimit(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||||
|
func (u *GroupUpsertOne) AddRpmLimit(v int) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddRpmLimit(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateRpmLimit() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateRpmLimit()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -2876,6 +2940,27 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetRpmLimit(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||||
|
func (u *GroupUpsertBulk) AddRpmLimit(v int) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddRpmLimit(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateRpmLimit() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateRpmLimit()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -567,6 +567,27 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate {
|
||||||
|
_u.mutation.ResetRpmLimit()
|
||||||
|
_u.mutation.SetRpmLimit(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdate) SetNillableRpmLimit(v *int) *GroupUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRpmLimit(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds value to the "rpm_limit" field.
|
||||||
|
func (_u *GroupUpdate) AddRpmLimit(v int) *GroupUpdate {
|
||||||
|
_u.mutation.AddRpmLimit(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1030,6 +1051,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
|
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
|
||||||
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||||
|
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedRpmLimit(); ok {
|
||||||
|
_spec.AddField(group.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1875,6 +1902,27 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne {
|
||||||
|
_u.mutation.ResetRpmLimit()
|
||||||
|
_u.mutation.SetRpmLimit(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdateOne) SetNillableRpmLimit(v *int) *GroupUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRpmLimit(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds value to the "rpm_limit" field.
|
||||||
|
func (_u *GroupUpdateOne) AddRpmLimit(v int) *GroupUpdateOne {
|
||||||
|
_u.mutation.AddRpmLimit(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -2368,6 +2416,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
|
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
|
||||||
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||||
|
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedRpmLimit(); ok {
|
||||||
|
_spec.AddField(group.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -654,6 +654,7 @@ var (
|
|||||||
{Name: "require_privacy_set", Type: field.TypeBool, Default: false},
|
{Name: "require_privacy_set", Type: field.TypeBool, Default: false},
|
||||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||||
{Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
{Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||||
|
{Name: "rpm_limit", Type: field.TypeInt, Default: 0},
|
||||||
}
|
}
|
||||||
// GroupsTable holds the schema information for the "groups" table.
|
// GroupsTable holds the schema information for the "groups" table.
|
||||||
GroupsTable = &schema.Table{
|
GroupsTable = &schema.Table{
|
||||||
@@ -1447,6 +1448,7 @@ var (
|
|||||||
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
|
{Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
|
||||||
{Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
|
{Name: "rpm_limit", Type: field.TypeInt, Default: 0},
|
||||||
}
|
}
|
||||||
// UsersTable holds the schema information for the "users" table.
|
// UsersTable holds the schema information for the "users" table.
|
||||||
UsersTable = &schema.Table{
|
UsersTable = &schema.Table{
|
||||||
|
|||||||
@@ -14787,6 +14787,8 @@ type GroupMutation struct {
|
|||||||
require_privacy_set *bool
|
require_privacy_set *bool
|
||||||
default_mapped_model *string
|
default_mapped_model *string
|
||||||
messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
|
messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
|
||||||
|
rpm_limit *int
|
||||||
|
addrpm_limit *int
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@@ -16375,6 +16377,62 @@ func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
|
|||||||
m.messages_dispatch_model_config = nil
|
m.messages_dispatch_model_config = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (m *GroupMutation) SetRpmLimit(i int) {
|
||||||
|
m.rpm_limit = &i
|
||||||
|
m.addrpm_limit = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimit returns the value of the "rpm_limit" field in the mutation.
|
||||||
|
func (m *GroupMutation) RpmLimit() (r int, exists bool) {
|
||||||
|
v := m.rpm_limit
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldRpmLimit returns the old "rpm_limit" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldRpmLimit(ctx context.Context) (v int, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldRpmLimit requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.RpmLimit, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds i to the "rpm_limit" field.
|
||||||
|
func (m *GroupMutation) AddRpmLimit(i int) {
|
||||||
|
if m.addrpm_limit != nil {
|
||||||
|
*m.addrpm_limit += i
|
||||||
|
} else {
|
||||||
|
m.addrpm_limit = &i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation.
|
||||||
|
func (m *GroupMutation) AddedRpmLimit() (r int, exists bool) {
|
||||||
|
v := m.addrpm_limit
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetRpmLimit resets all changes to the "rpm_limit" field.
|
||||||
|
func (m *GroupMutation) ResetRpmLimit() {
|
||||||
|
m.rpm_limit = nil
|
||||||
|
m.addrpm_limit = nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@@ -16733,7 +16791,7 @@ func (m *GroupMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *GroupMutation) Fields() []string {
|
func (m *GroupMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 30)
|
fields := make([]string, 0, 31)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, group.FieldCreatedAt)
|
fields = append(fields, group.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -16824,6 +16882,9 @@ func (m *GroupMutation) Fields() []string {
|
|||||||
if m.messages_dispatch_model_config != nil {
|
if m.messages_dispatch_model_config != nil {
|
||||||
fields = append(fields, group.FieldMessagesDispatchModelConfig)
|
fields = append(fields, group.FieldMessagesDispatchModelConfig)
|
||||||
}
|
}
|
||||||
|
if m.rpm_limit != nil {
|
||||||
|
fields = append(fields, group.FieldRpmLimit)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -16892,6 +16953,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.DefaultMappedModel()
|
return m.DefaultMappedModel()
|
||||||
case group.FieldMessagesDispatchModelConfig:
|
case group.FieldMessagesDispatchModelConfig:
|
||||||
return m.MessagesDispatchModelConfig()
|
return m.MessagesDispatchModelConfig()
|
||||||
|
case group.FieldRpmLimit:
|
||||||
|
return m.RpmLimit()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -16961,6 +17024,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
|||||||
return m.OldDefaultMappedModel(ctx)
|
return m.OldDefaultMappedModel(ctx)
|
||||||
case group.FieldMessagesDispatchModelConfig:
|
case group.FieldMessagesDispatchModelConfig:
|
||||||
return m.OldMessagesDispatchModelConfig(ctx)
|
return m.OldMessagesDispatchModelConfig(ctx)
|
||||||
|
case group.FieldRpmLimit:
|
||||||
|
return m.OldRpmLimit(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -17180,6 +17245,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetMessagesDispatchModelConfig(v)
|
m.SetMessagesDispatchModelConfig(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldRpmLimit:
|
||||||
|
v, ok := value.(int)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetRpmLimit(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -17221,6 +17293,9 @@ func (m *GroupMutation) AddedFields() []string {
|
|||||||
if m.addsort_order != nil {
|
if m.addsort_order != nil {
|
||||||
fields = append(fields, group.FieldSortOrder)
|
fields = append(fields, group.FieldSortOrder)
|
||||||
}
|
}
|
||||||
|
if m.addrpm_limit != nil {
|
||||||
|
fields = append(fields, group.FieldRpmLimit)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -17251,6 +17326,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedFallbackGroupIDOnInvalidRequest()
|
return m.AddedFallbackGroupIDOnInvalidRequest()
|
||||||
case group.FieldSortOrder:
|
case group.FieldSortOrder:
|
||||||
return m.AddedSortOrder()
|
return m.AddedSortOrder()
|
||||||
|
case group.FieldRpmLimit:
|
||||||
|
return m.AddedRpmLimit()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -17337,6 +17414,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddSortOrder(v)
|
m.AddSortOrder(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldRpmLimit:
|
||||||
|
v, ok := value.(int)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddRpmLimit(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group numeric field %s", name)
|
return fmt.Errorf("unknown Group numeric field %s", name)
|
||||||
}
|
}
|
||||||
@@ -17523,6 +17607,9 @@ func (m *GroupMutation) ResetField(name string) error {
|
|||||||
case group.FieldMessagesDispatchModelConfig:
|
case group.FieldMessagesDispatchModelConfig:
|
||||||
m.ResetMessagesDispatchModelConfig()
|
m.ResetMessagesDispatchModelConfig()
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldRpmLimit:
|
||||||
|
m.ResetRpmLimit()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -37366,6 +37453,8 @@ type UserMutation struct {
|
|||||||
balance_notify_extra_emails *string
|
balance_notify_extra_emails *string
|
||||||
total_recharged *float64
|
total_recharged *float64
|
||||||
addtotal_recharged *float64
|
addtotal_recharged *float64
|
||||||
|
rpm_limit *int
|
||||||
|
addrpm_limit *int
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@@ -38457,6 +38546,62 @@ func (m *UserMutation) ResetTotalRecharged() {
|
|||||||
m.addtotal_recharged = nil
|
m.addtotal_recharged = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (m *UserMutation) SetRpmLimit(i int) {
|
||||||
|
m.rpm_limit = &i
|
||||||
|
m.addrpm_limit = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimit returns the value of the "rpm_limit" field in the mutation.
|
||||||
|
func (m *UserMutation) RpmLimit() (r int, exists bool) {
|
||||||
|
v := m.rpm_limit
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldRpmLimit returns the old "rpm_limit" field's value of the User entity.
|
||||||
|
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UserMutation) OldRpmLimit(ctx context.Context) (v int, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldRpmLimit requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.RpmLimit, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds i to the "rpm_limit" field.
|
||||||
|
func (m *UserMutation) AddRpmLimit(i int) {
|
||||||
|
if m.addrpm_limit != nil {
|
||||||
|
*m.addrpm_limit += i
|
||||||
|
} else {
|
||||||
|
m.addrpm_limit = &i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation.
|
||||||
|
func (m *UserMutation) AddedRpmLimit() (r int, exists bool) {
|
||||||
|
v := m.addrpm_limit
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetRpmLimit resets all changes to the "rpm_limit" field.
|
||||||
|
func (m *UserMutation) ResetRpmLimit() {
|
||||||
|
m.rpm_limit = nil
|
||||||
|
m.addrpm_limit = nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@@ -39139,7 +39284,7 @@ func (m *UserMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UserMutation) Fields() []string {
|
func (m *UserMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 22)
|
fields := make([]string, 0, 23)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, user.FieldCreatedAt)
|
fields = append(fields, user.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -39206,6 +39351,9 @@ func (m *UserMutation) Fields() []string {
|
|||||||
if m.total_recharged != nil {
|
if m.total_recharged != nil {
|
||||||
fields = append(fields, user.FieldTotalRecharged)
|
fields = append(fields, user.FieldTotalRecharged)
|
||||||
}
|
}
|
||||||
|
if m.rpm_limit != nil {
|
||||||
|
fields = append(fields, user.FieldRpmLimit)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39258,6 +39406,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.BalanceNotifyExtraEmails()
|
return m.BalanceNotifyExtraEmails()
|
||||||
case user.FieldTotalRecharged:
|
case user.FieldTotalRecharged:
|
||||||
return m.TotalRecharged()
|
return m.TotalRecharged()
|
||||||
|
case user.FieldRpmLimit:
|
||||||
|
return m.RpmLimit()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -39311,6 +39461,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
|
|||||||
return m.OldBalanceNotifyExtraEmails(ctx)
|
return m.OldBalanceNotifyExtraEmails(ctx)
|
||||||
case user.FieldTotalRecharged:
|
case user.FieldTotalRecharged:
|
||||||
return m.OldTotalRecharged(ctx)
|
return m.OldTotalRecharged(ctx)
|
||||||
|
case user.FieldRpmLimit:
|
||||||
|
return m.OldRpmLimit(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown User field %s", name)
|
return nil, fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
@@ -39474,6 +39626,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetTotalRecharged(v)
|
m.SetTotalRecharged(v)
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldRpmLimit:
|
||||||
|
v, ok := value.(int)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetRpmLimit(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User field %s", name)
|
return fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
@@ -39494,6 +39653,9 @@ func (m *UserMutation) AddedFields() []string {
|
|||||||
if m.addtotal_recharged != nil {
|
if m.addtotal_recharged != nil {
|
||||||
fields = append(fields, user.FieldTotalRecharged)
|
fields = append(fields, user.FieldTotalRecharged)
|
||||||
}
|
}
|
||||||
|
if m.addrpm_limit != nil {
|
||||||
|
fields = append(fields, user.FieldRpmLimit)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39510,6 +39672,8 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedBalanceNotifyThreshold()
|
return m.AddedBalanceNotifyThreshold()
|
||||||
case user.FieldTotalRecharged:
|
case user.FieldTotalRecharged:
|
||||||
return m.AddedTotalRecharged()
|
return m.AddedTotalRecharged()
|
||||||
|
case user.FieldRpmLimit:
|
||||||
|
return m.AddedRpmLimit()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -39547,6 +39711,13 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddTotalRecharged(v)
|
m.AddTotalRecharged(v)
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldRpmLimit:
|
||||||
|
v, ok := value.(int)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddRpmLimit(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User numeric field %s", name)
|
return fmt.Errorf("unknown User numeric field %s", name)
|
||||||
}
|
}
|
||||||
@@ -39679,6 +39850,9 @@ func (m *UserMutation) ResetField(name string) error {
|
|||||||
case user.FieldTotalRecharged:
|
case user.FieldTotalRecharged:
|
||||||
m.ResetTotalRecharged()
|
m.ResetTotalRecharged()
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldRpmLimit:
|
||||||
|
m.ResetRpmLimit()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User field %s", name)
|
return fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -845,6 +845,10 @@ func init() {
|
|||||||
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
|
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
|
||||||
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
|
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
|
||||||
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
|
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
|
||||||
|
// groupDescRpmLimit is the schema descriptor for rpm_limit field.
|
||||||
|
groupDescRpmLimit := groupFields[27].Descriptor()
|
||||||
|
// group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
|
||||||
|
group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
|
||||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||||
_ = idempotencyrecordMixinFields0
|
_ = idempotencyrecordMixinFields0
|
||||||
@@ -1825,6 +1829,10 @@ func init() {
|
|||||||
userDescTotalRecharged := userFields[18].Descriptor()
|
userDescTotalRecharged := userFields[18].Descriptor()
|
||||||
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
|
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
|
||||||
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
|
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
|
||||||
|
// userDescRpmLimit is the schema descriptor for rpm_limit field.
|
||||||
|
userDescRpmLimit := userFields[19].Descriptor()
|
||||||
|
// user.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
|
||||||
|
user.DefaultRpmLimit = userDescRpmLimit.Default.(int)
|
||||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||||
_ = userallowedgroupFields
|
_ = userallowedgroupFields
|
||||||
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
||||||
|
|||||||
@@ -145,6 +145,11 @@ func (Group) Fields() []ent.Field {
|
|||||||
Default(domain.OpenAIMessagesDispatchModelConfig{}).
|
Default(domain.OpenAIMessagesDispatchModelConfig{}).
|
||||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||||
Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
|
Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
|
||||||
|
|
||||||
|
// 分组级每分钟请求数上限(0 = 不限制)。设置后优先于用户级兜底生效。
|
||||||
|
field.Int("rpm_limit").
|
||||||
|
Default(0).
|
||||||
|
Comment("分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -108,6 +108,10 @@ func (User) Fields() []ent.Field {
|
|||||||
field.Float("total_recharged").
|
field.Float("total_recharged").
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||||
Default(0),
|
Default(0),
|
||||||
|
|
||||||
|
// 用户级每分钟请求数上限(0 = 不限制)。仅当所在分组未设置 rpm_limit 时作为兜底生效。
|
||||||
|
field.Int("rpm_limit").
|
||||||
|
Default(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -61,6 +61,8 @@ type User struct {
|
|||||||
BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
|
BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
|
||||||
// TotalRecharged holds the value of the "total_recharged" field.
|
// TotalRecharged holds the value of the "total_recharged" field.
|
||||||
TotalRecharged float64 `json:"total_recharged,omitempty"`
|
TotalRecharged float64 `json:"total_recharged,omitempty"`
|
||||||
|
// RpmLimit holds the value of the "rpm_limit" field.
|
||||||
|
RpmLimit int `json:"rpm_limit,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the UserQuery when eager-loading is set.
|
// The values are being populated by the UserQuery when eager-loading is set.
|
||||||
Edges UserEdges `json:"edges"`
|
Edges UserEdges `json:"edges"`
|
||||||
@@ -226,7 +228,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
|
case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case user.FieldID, user.FieldConcurrency:
|
case user.FieldID, user.FieldConcurrency, user.FieldRpmLimit:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
|
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
@@ -391,6 +393,12 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.TotalRecharged = value.Float64
|
_m.TotalRecharged = value.Float64
|
||||||
}
|
}
|
||||||
|
case user.FieldRpmLimit:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.RpmLimit = int(value.Int64)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -569,6 +577,9 @@ func (_m *User) String() string {
|
|||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("total_recharged=")
|
builder.WriteString("total_recharged=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
|
builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("rpm_limit=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ const (
|
|||||||
FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
|
FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
|
||||||
// FieldTotalRecharged holds the string denoting the total_recharged field in the database.
|
// FieldTotalRecharged holds the string denoting the total_recharged field in the database.
|
||||||
FieldTotalRecharged = "total_recharged"
|
FieldTotalRecharged = "total_recharged"
|
||||||
|
// FieldRpmLimit holds the string denoting the rpm_limit field in the database.
|
||||||
|
FieldRpmLimit = "rpm_limit"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -203,6 +205,7 @@ var Columns = []string{
|
|||||||
FieldBalanceNotifyThreshold,
|
FieldBalanceNotifyThreshold,
|
||||||
FieldBalanceNotifyExtraEmails,
|
FieldBalanceNotifyExtraEmails,
|
||||||
FieldTotalRecharged,
|
FieldTotalRecharged,
|
||||||
|
FieldRpmLimit,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -271,6 +274,8 @@ var (
|
|||||||
DefaultBalanceNotifyExtraEmails string
|
DefaultBalanceNotifyExtraEmails string
|
||||||
// DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
|
// DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
|
||||||
DefaultTotalRecharged float64
|
DefaultTotalRecharged float64
|
||||||
|
// DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
|
||||||
|
DefaultRpmLimit int
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the User queries.
|
// OrderOption defines the ordering options for the User queries.
|
||||||
@@ -391,6 +396,11 @@ func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
|
return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByRpmLimit orders the results by the rpm_limit field.
|
||||||
|
func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldRpmLimit, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -165,6 +165,11 @@ func TotalRecharged(v float64) predicate.User {
|
|||||||
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
|
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ.
|
||||||
|
func RpmLimit(v int) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.User {
|
func CreatedAtEQ(v time.Time) predicate.User {
|
||||||
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -1295,6 +1300,46 @@ func TotalRechargedLTE(v float64) predicate.User {
|
|||||||
return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
|
return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitEQ(v int) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitNEQ(v int) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitIn applies the In predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitIn(vs ...int) predicate.User {
|
||||||
|
return predicate.User(sql.FieldIn(FieldRpmLimit, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitNotIn(vs ...int) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotIn(FieldRpmLimit, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitGT applies the GT predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitGT(v int) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGT(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitGTE(v int) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGTE(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitLT applies the LT predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitLT(v int) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLT(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field.
|
||||||
|
func RpmLimitLTE(v int) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLTE(FieldRpmLimit, v))
|
||||||
|
}
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.User {
|
func HasAPIKeys() predicate.User {
|
||||||
return predicate.User(func(s *sql.Selector) {
|
return predicate.User(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -325,6 +325,20 @@ func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (_c *UserCreate) SetRpmLimit(v int) *UserCreate {
|
||||||
|
_c.mutation.SetRpmLimit(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableRpmLimit(v *int) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetRpmLimit(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -604,6 +618,10 @@ func (_c *UserCreate) defaults() error {
|
|||||||
v := user.DefaultTotalRecharged
|
v := user.DefaultTotalRecharged
|
||||||
_c.mutation.SetTotalRecharged(v)
|
_c.mutation.SetTotalRecharged(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||||
|
v := user.DefaultRpmLimit
|
||||||
|
_c.mutation.SetRpmLimit(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -687,6 +705,9 @@ func (_c *UserCreate) check() error {
|
|||||||
if _, ok := _c.mutation.TotalRecharged(); !ok {
|
if _, ok := _c.mutation.TotalRecharged(); !ok {
|
||||||
return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
|
return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||||
|
return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "User.rpm_limit"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -802,6 +823,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||||
_node.TotalRecharged = value
|
_node.TotalRecharged = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.RpmLimit(); ok {
|
||||||
|
_spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
_node.RpmLimit = value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1362,6 +1387,24 @@ func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (u *UserUpsert) SetRpmLimit(v int) *UserUpsert {
|
||||||
|
u.Set(user.FieldRpmLimit, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateRpmLimit() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldRpmLimit)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||||
|
func (u *UserUpsert) AddRpmLimit(v int) *UserUpsert {
|
||||||
|
u.Add(user.FieldRpmLimit, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -1771,6 +1814,27 @@ func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (u *UserUpsertOne) SetRpmLimit(v int) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetRpmLimit(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||||
|
func (u *UserUpsertOne) AddRpmLimit(v int) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddRpmLimit(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateRpmLimit() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateRpmLimit()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -2346,6 +2410,27 @@ func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (u *UserUpsertBulk) SetRpmLimit(v int) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetRpmLimit(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||||
|
func (u *UserUpsertBulk) AddRpmLimit(v int) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddRpmLimit(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateRpmLimit() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateRpmLimit()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -389,6 +389,27 @@ func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (_u *UserUpdate) SetRpmLimit(v int) *UserUpdate {
|
||||||
|
_u.mutation.ResetRpmLimit()
|
||||||
|
_u.mutation.SetRpmLimit(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableRpmLimit(v *int) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRpmLimit(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds value to the "rpm_limit" field.
|
||||||
|
func (_u *UserUpdate) AddRpmLimit(v int) *UserUpdate {
|
||||||
|
_u.mutation.AddRpmLimit(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1008,6 +1029,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
|
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
|
||||||
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||||
|
_spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedRpmLimit(); ok {
|
||||||
|
_spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1930,6 +1957,27 @@ func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRpmLimit sets the "rpm_limit" field.
|
||||||
|
func (_u *UserUpdateOne) SetRpmLimit(v int) *UserUpdateOne {
|
||||||
|
_u.mutation.ResetRpmLimit()
|
||||||
|
_u.mutation.SetRpmLimit(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableRpmLimit(v *int) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRpmLimit(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRpmLimit adds value to the "rpm_limit" field.
|
||||||
|
func (_u *UserUpdateOne) AddRpmLimit(v int) *UserUpdateOne {
|
||||||
|
_u.mutation.AddRpmLimit(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -2579,6 +2627,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
|||||||
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
|
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
|
||||||
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||||
|
_spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedRpmLimit(); ok {
|
||||||
|
_spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -104,6 +104,7 @@ require (
|
|||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
github.com/google/go-cmp v0.7.0 // indirect
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
github.com/google/go-querystring v1.1.0 // indirect
|
github.com/google/go-querystring v1.1.0 // indirect
|
||||||
|
github.com/google/subcommands v1.2.0 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||||
|
|||||||
@@ -183,6 +183,17 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
|
|||||||
return map[string]any{"user_id": userID}, nil
|
return map[string]any{"user_id": userID}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) {
|
||||||
|
user, err := s.GetUser(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &service.UserRPMStatus{
|
||||||
|
UserRPMUsed: 0,
|
||||||
|
UserRPMLimit: user.RPMLimit,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
|
func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
|
||||||
s.boundAuthIdentityFor = userID
|
s.boundAuthIdentityFor = userID
|
||||||
copied := input
|
copied := input
|
||||||
@@ -276,6 +287,14 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
|
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
|
||||||
s.lastListAccounts.platform = platform
|
s.lastListAccounts.platform = platform
|
||||||
s.lastListAccounts.accountType = accountType
|
s.lastListAccounts.accountType = accountType
|
||||||
|
|||||||
@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
|
|||||||
RequirePrivacySet bool `json:"require_privacy_set"`
|
RequirePrivacySet bool `json:"require_privacy_set"`
|
||||||
DefaultMappedModel string `json:"default_mapped_model"`
|
DefaultMappedModel string `json:"default_mapped_model"`
|
||||||
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
||||||
|
// 分组 RPM 上限(0 = 不限制)
|
||||||
|
RPMLimit int `json:"rpm_limit"`
|
||||||
// 从指定分组复制账号(创建后自动绑定)
|
// 从指定分组复制账号(创建后自动绑定)
|
||||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||||
}
|
}
|
||||||
@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
|
|||||||
RequirePrivacySet *bool `json:"require_privacy_set"`
|
RequirePrivacySet *bool `json:"require_privacy_set"`
|
||||||
DefaultMappedModel *string `json:"default_mapped_model"`
|
DefaultMappedModel *string `json:"default_mapped_model"`
|
||||||
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
||||||
|
// 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
|
||||||
|
RPMLimit *int `json:"rpm_limit"`
|
||||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||||
}
|
}
|
||||||
@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
RequirePrivacySet: req.RequirePrivacySet,
|
RequirePrivacySet: req.RequirePrivacySet,
|
||||||
DefaultMappedModel: req.DefaultMappedModel,
|
DefaultMappedModel: req.DefaultMappedModel,
|
||||||
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
||||||
|
RPMLimit: req.RPMLimit,
|
||||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
RequirePrivacySet: req.RequirePrivacySet,
|
RequirePrivacySet: req.RequirePrivacySet,
|
||||||
DefaultMappedModel: req.DefaultMappedModel,
|
DefaultMappedModel: req.DefaultMappedModel,
|
||||||
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
||||||
|
RPMLimit: req.RPMLimit,
|
||||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
|||||||
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
|
||||||
|
type BatchSetGroupRPMOverridesRequest struct {
|
||||||
|
Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
|
||||||
|
// PUT /api/v1/admin/groups/:id/rpm-overrides
|
||||||
|
func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req BatchSetGroupRPMOverridesRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "RPM overrides updated successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearGroupRPMOverrides handles clearing all rpm_override for a group
|
||||||
|
// DELETE /api/v1/admin/groups/:id/rpm-overrides
|
||||||
|
func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "RPM overrides cleared successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||||
type UpdateSortOrderRequest struct {
|
type UpdateSortOrderRequest struct {
|
||||||
Updates []struct {
|
Updates []struct {
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||||
DefaultConcurrency: settings.DefaultConcurrency,
|
DefaultConcurrency: settings.DefaultConcurrency,
|
||||||
DefaultBalance: settings.DefaultBalance,
|
DefaultBalance: settings.DefaultBalance,
|
||||||
|
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
|
||||||
DefaultSubscriptions: defaultSubscriptions,
|
DefaultSubscriptions: defaultSubscriptions,
|
||||||
EnableModelFallback: settings.EnableModelFallback,
|
EnableModelFallback: settings.EnableModelFallback,
|
||||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||||
@@ -337,6 +338,7 @@ type UpdateSettingsRequest struct {
|
|||||||
// 默认配置
|
// 默认配置
|
||||||
DefaultConcurrency int `json:"default_concurrency"`
|
DefaultConcurrency int `json:"default_concurrency"`
|
||||||
DefaultBalance float64 `json:"default_balance"`
|
DefaultBalance float64 `json:"default_balance"`
|
||||||
|
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||||
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
|
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
|
||||||
@@ -1117,6 +1119,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
CustomEndpoints: customEndpointsJSON,
|
CustomEndpoints: customEndpointsJSON,
|
||||||
DefaultConcurrency: req.DefaultConcurrency,
|
DefaultConcurrency: req.DefaultConcurrency,
|
||||||
DefaultBalance: req.DefaultBalance,
|
DefaultBalance: req.DefaultBalance,
|
||||||
|
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||||
DefaultSubscriptions: defaultSubscriptions,
|
DefaultSubscriptions: defaultSubscriptions,
|
||||||
EnableModelFallback: req.EnableModelFallback,
|
EnableModelFallback: req.EnableModelFallback,
|
||||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||||
@@ -1430,6 +1433,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
||||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||||
DefaultBalance: updatedSettings.DefaultBalance,
|
DefaultBalance: updatedSettings.DefaultBalance,
|
||||||
|
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
|
||||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ type CreateUserRequest struct {
|
|||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
Balance float64 `json:"balance"`
|
Balance float64 `json:"balance"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
|
RPMLimit int `json:"rpm_limit"`
|
||||||
AllowedGroups []int64 `json:"allowed_groups"`
|
AllowedGroups []int64 `json:"allowed_groups"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
|
|||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Balance *float64 `json:"balance"`
|
Balance *float64 `json:"balance"`
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
|
RPMLimit *int `json:"rpm_limit"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||||
// GroupRates 用户专属分组倍率配置
|
// GroupRates 用户专属分组倍率配置
|
||||||
@@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
|||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
|
RPMLimit: req.RPMLimit,
|
||||||
AllowedGroups: req.AllowedGroups,
|
AllowedGroups: req.AllowedGroups,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
|||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
|
RPMLimit: req.RPMLimit,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
AllowedGroups: req.AllowedGroups,
|
AllowedGroups: req.AllowedGroups,
|
||||||
GroupRates: req.GroupRates,
|
GroupRates: req.GroupRates,
|
||||||
@@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
|
|||||||
"migrated_keys": result.MigratedKeys,
|
"migrated_keys": result.MigratedKeys,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
|
||||||
|
// GET /api/v1/admin/users/:id/rpm-status
|
||||||
|
func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
|
||||||
|
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, status)
|
||||||
|
}
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User {
|
|||||||
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
||||||
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
|
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
|
||||||
TotalRecharged: u.TotalRecharged,
|
TotalRecharged: u.TotalRecharged,
|
||||||
|
RPMLimit: u.RPMLimit,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
|||||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||||
RequireOAuthOnly: g.RequireOAuthOnly,
|
RequireOAuthOnly: g.RequireOAuthOnly,
|
||||||
RequirePrivacySet: g.RequirePrivacySet,
|
RequirePrivacySet: g.RequirePrivacySet,
|
||||||
|
RPMLimit: g.RPMLimit,
|
||||||
CreatedAt: g.CreatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
UpdatedAt: g.UpdatedAt,
|
UpdatedAt: g.UpdatedAt,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ type SystemSettings struct {
|
|||||||
|
|
||||||
DefaultConcurrency int `json:"default_concurrency"`
|
DefaultConcurrency int `json:"default_concurrency"`
|
||||||
DefaultBalance float64 `json:"default_balance"`
|
DefaultBalance float64 `json:"default_balance"`
|
||||||
|
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||||
|
|
||||||
// Model fallback configuration
|
// Model fallback configuration
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ type User struct {
|
|||||||
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
|
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
|
||||||
TotalRecharged float64 `json:"total_recharged"`
|
TotalRecharged float64 `json:"total_recharged"`
|
||||||
|
|
||||||
|
// RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。
|
||||||
|
RPMLimit int `json:"rpm_limit"`
|
||||||
|
|
||||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -108,6 +111,9 @@ type Group struct {
|
|||||||
RequireOAuthOnly bool `json:"require_oauth_only"`
|
RequireOAuthOnly bool `json:"require_oauth_only"`
|
||||||
RequirePrivacySet bool `json:"require_privacy_set"`
|
RequirePrivacySet bool `json:"require_privacy_set"`
|
||||||
|
|
||||||
|
// RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。
|
||||||
|
RPMLimit int `json:"rpm_limit"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 2. 【新增】Wait后二次检查余额/订阅
|
// 2. 【新增】Wait后二次检查余额/订阅
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
|
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
|
||||||
status, code, message := billingErrorDetails(err)
|
status, code, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -758,7 +761,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
|
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
|
||||||
status, code, message := billingErrorDetails(err)
|
status, code, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1464,7 +1470,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
// 校验 billing eligibility(订阅/余额)
|
// 校验 billing eligibility(订阅/余额)
|
||||||
// 【注意】不计算并发,但需要校验订阅/余额
|
// 【注意】不计算并发,但需要校验订阅/余额
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
status, code, message := billingErrorDetails(err)
|
status, code, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
h.errorResponse(c, status, code, message)
|
h.errorResponse(c, status, code, message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1707,25 +1716,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
|
|||||||
c.JSON(http.StatusOK, response)
|
c.JSON(http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func billingErrorDetails(err error) (status int, code, message string) {
|
func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
|
||||||
if errors.Is(err, service.ErrBillingServiceUnavailable) {
|
if errors.Is(err, service.ErrBillingServiceUnavailable) {
|
||||||
msg := pkgerrors.Message(err)
|
msg := pkgerrors.Message(err)
|
||||||
if msg == "" {
|
if msg == "" {
|
||||||
msg = "Billing service temporarily unavailable. Please retry later."
|
msg = "Billing service temporarily unavailable. Please retry later."
|
||||||
}
|
}
|
||||||
return http.StatusServiceUnavailable, "billing_service_error", msg
|
return http.StatusServiceUnavailable, "billing_service_error", msg, 0
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
|
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
|
||||||
msg := pkgerrors.Message(err)
|
msg := pkgerrors.Message(err)
|
||||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
|
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
|
||||||
msg := pkgerrors.Message(err)
|
msg := pkgerrors.Message(err)
|
||||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
|
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
|
||||||
msg := pkgerrors.Message(err)
|
msg := pkgerrors.Message(err)
|
||||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
|
||||||
|
}
|
||||||
|
// 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。
|
||||||
|
// 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。
|
||||||
|
if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) {
|
||||||
|
msg := pkgerrors.Message(err)
|
||||||
|
retrySeconds := 60 - int(time.Now().Unix()%60)
|
||||||
|
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
|
||||||
}
|
}
|
||||||
msg := pkgerrors.Message(err)
|
msg := pkgerrors.Message(err)
|
||||||
if msg == "" {
|
if msg == "" {
|
||||||
@@ -1735,7 +1751,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
|||||||
).Warn("gateway.billing_error_missing_message")
|
).Warn("gateway.billing_error_missing_message")
|
||||||
msg = "Billing error"
|
msg = "Billing error"
|
||||||
}
|
}
|
||||||
return http.StatusForbidden, "billing_error", msg
|
return http.StatusForbidden, "billing_error", msg, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *GatewayHandler) metadataBridgeEnabled() bool {
|
func (h *GatewayHandler) metadataBridgeEnabled() bool {
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) {
|
||||||
|
status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, status)
|
||||||
|
require.Equal(t, "rate_limit_exceeded", code)
|
||||||
|
require.NotEmpty(t, msg)
|
||||||
|
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
|
||||||
|
require.LessOrEqual(t, retryAfter, 60)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) {
|
||||||
|
status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, status)
|
||||||
|
require.Equal(t, "rate_limit_exceeded", code)
|
||||||
|
require.NotEmpty(t, msg)
|
||||||
|
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
|
||||||
|
require.LessOrEqual(t, retryAfter, 60)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) {
|
||||||
|
// 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。
|
||||||
|
for _, err := range []error{
|
||||||
|
service.ErrAPIKeyRateLimit5hExceeded,
|
||||||
|
service.ErrAPIKeyRateLimit1dExceeded,
|
||||||
|
service.ErrAPIKeyRateLimit7dExceeded,
|
||||||
|
} {
|
||||||
|
status, code, _, _ := billingErrorDetails(err)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err)
|
||||||
|
require.Equal(t, "rate_limit_exceeded", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) {
|
||||||
|
status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable)
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, status)
|
||||||
|
require.Equal(t, "billing_service_error", code)
|
||||||
|
require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
|
||||||
|
status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance)
|
||||||
|
require.Equal(t, http.StatusForbidden, status)
|
||||||
|
require.Equal(t, "billing_error", code)
|
||||||
|
require.NotEmpty(t, msg)
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
// 2. Re-check billing
|
// 2. Re-check billing
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
|
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
|
||||||
status, code, message := billingErrorDetails(err)
|
status, code, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
h.chatCompletionsErrorResponse(c, status, code, message)
|
h.chatCompletionsErrorResponse(c, status, code, message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 2. Re-check billing
|
// 2. Re-check billing
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
||||||
status, code, message := billingErrorDetails(err)
|
status, code, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
h.responsesErrorResponse(c, status, code, message)
|
h.responsesErrorResponse(c, status, code, message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
|||||||
|
|
||||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||||
|
|
||||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
// 2) billing eligibility check (after wait)
|
// 2) billing eligibility check (after wait)
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
|
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
|
||||||
status, _, message := billingErrorDetails(err)
|
status, _, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
googleError(c, status, message)
|
googleError(c, status, message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||||
status, code, message := billingErrorDetails(err)
|
status, code, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 2. Re-check billing eligibility after wait
|
// 2. Re-check billing eligibility after wait
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
|
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
|
||||||
status, code, message := billingErrorDetails(err)
|
status, code, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
||||||
status, code, message := billingErrorDetails(err)
|
status, code, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
|
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
|||||||
|
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
|
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
|
||||||
status, code, message := billingErrorDetails(err)
|
status, code, message, retryAfter := billingErrorDetails(err)
|
||||||
|
if retryAfter > 0 {
|
||||||
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
}
|
||||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
user.FieldSignupSource,
|
user.FieldSignupSource,
|
||||||
user.FieldLastLoginAt,
|
user.FieldLastLoginAt,
|
||||||
user.FieldLastActiveAt,
|
user.FieldLastActiveAt,
|
||||||
|
user.FieldRpmLimit,
|
||||||
)
|
)
|
||||||
}).
|
}).
|
||||||
WithGroup(func(q *dbent.GroupQuery) {
|
WithGroup(func(q *dbent.GroupQuery) {
|
||||||
@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
group.FieldAllowMessagesDispatch,
|
group.FieldAllowMessagesDispatch,
|
||||||
group.FieldDefaultMappedModel,
|
group.FieldDefaultMappedModel,
|
||||||
group.FieldMessagesDispatchModelConfig,
|
group.FieldMessagesDispatchModelConfig,
|
||||||
|
group.FieldRpmLimit,
|
||||||
)
|
)
|
||||||
}).
|
}).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
|
|||||||
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
|
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
|
||||||
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
||||||
TotalRecharged: u.TotalRecharged,
|
TotalRecharged: u.TotalRecharged,
|
||||||
|
RPMLimit: u.RpmLimit,
|
||||||
CreatedAt: u.CreatedAt,
|
CreatedAt: u.CreatedAt,
|
||||||
UpdatedAt: u.UpdatedAt,
|
UpdatedAt: u.UpdatedAt,
|
||||||
}
|
}
|
||||||
@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
RequirePrivacySet: g.RequirePrivacySet,
|
RequirePrivacySet: g.RequirePrivacySet,
|
||||||
DefaultMappedModel: g.DefaultMappedModel,
|
DefaultMappedModel: g.DefaultMappedModel,
|
||||||
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
|
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
|
||||||
|
RPMLimit: g.RpmLimit,
|
||||||
CreatedAt: g.CreatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
UpdatedAt: g.UpdatedAt,
|
UpdatedAt: g.UpdatedAt,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
|
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||||
|
SetRpmLimit(groupIn.RPMLimit)
|
||||||
|
|
||||||
// 设置模型路由配置
|
// 设置模型路由配置
|
||||||
if groupIn.ModelRouting != nil {
|
if groupIn.ModelRouting != nil {
|
||||||
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
|
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||||
|
SetRpmLimit(groupIn.RPMLimit)
|
||||||
|
|
||||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||||
if groupIn.DailyLimitUSD != nil {
|
if groupIn.DailyLimitUSD != nil {
|
||||||
|
|||||||
51
backend/internal/repository/openai_403_counter_cache.go
Normal file
51
backend/internal/repository/openai_403_counter_cache.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const openAI403CounterPrefix = "openai_403_count:account:"
|
||||||
|
|
||||||
|
var openAI403CounterIncrScript = redis.NewScript(`
|
||||||
|
local key = KEYS[1]
|
||||||
|
local ttl = tonumber(ARGV[1])
|
||||||
|
|
||||||
|
local count = redis.call('INCR', key)
|
||||||
|
if count == 1 then
|
||||||
|
redis.call('EXPIRE', key, ttl)
|
||||||
|
end
|
||||||
|
|
||||||
|
return count
|
||||||
|
`)
|
||||||
|
|
||||||
|
type openAI403CounterCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache {
|
||||||
|
return &openAI403CounterCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
|
||||||
|
|
||||||
|
ttlSeconds := windowMinutes * 60
|
||||||
|
if ttlSeconds < 60 {
|
||||||
|
ttlSeconds = 60
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("increment openai 403 count: %w", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -53,6 +54,9 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
|||||||
Post(s.tokenURL)
|
Post(s.tokenURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
|
||||||
|
return nil, newOpenAINoProxyHintError(err)
|
||||||
|
}
|
||||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,6 +102,9 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre
|
|||||||
Post(s.tokenURL)
|
Post(s.tokenURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
|
||||||
|
return nil, newOpenAINoProxyHintError(err)
|
||||||
|
}
|
||||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,3 +121,21 @@ func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
|
|||||||
Timeout: 120 * time.Second,
|
Timeout: 120 * time.Second,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldReturnOpenAINoProxyHint(ctx context.Context, proxyURL string, err error) bool {
|
||||||
|
if strings.TrimSpace(proxyURL) != "" || err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if ctx != nil && ctx.Err() != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !errors.Is(err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOpenAINoProxyHintError(cause error) error {
|
||||||
|
return infraerrors.New(
|
||||||
|
http.StatusBadGateway,
|
||||||
|
"OPENAI_OAUTH_PROXY_REQUIRED",
|
||||||
|
"OpenAI OAuth request failed: no proxy is configured and this server could not reach OpenAI directly. Select a proxy that can access OpenAI, then retry; if the authorization code has expired, regenerate the authorization URL.",
|
||||||
|
).WithCause(cause)
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@@ -204,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
|
|||||||
require.ErrorContains(s.T(), err, "request failed")
|
require.ErrorContains(s.T(), err, "request failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_RequestErrorWithoutProxyReturnsProxyHint() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
s.srv.Close()
|
||||||
|
|
||||||
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||||
|
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.Equal(s.T(), "OPENAI_OAUTH_PROXY_REQUIRED", infraerrors.Reason(err))
|
||||||
|
require.Contains(s.T(), infraerrors.Message(err), "no proxy is configured")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
||||||
started := make(chan struct{})
|
started := make(chan struct{})
|
||||||
block := make(chan struct{})
|
block := make(chan struct{})
|
||||||
|
|||||||
@@ -290,7 +290,6 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
|
||||||
|
|
||||||
var state service.AccountQuotaState
|
var state service.AccountQuotaState
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
@@ -299,18 +298,36 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|
|||||||
&state.DailyUsed, &state.DailyLimit,
|
&state.DailyUsed, &state.DailyLimit,
|
||||||
&state.WeeklyUsed, &state.WeeklyLimit,
|
&state.WeeklyUsed, &state.WeeklyLimit,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
_ = rows.Close()
|
||||||
return nil, service.ErrAccountNotFound
|
return nil, service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit {
|
// 必须在执行下一条 SQL 前显式关闭 rows:pq 驱动在同一连接上
|
||||||
|
// 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回
|
||||||
|
// "unexpected Parse response" 错误。
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照,
|
||||||
|
// 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号,
|
||||||
|
// 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。
|
||||||
|
// 对于日/周额度,即使本次触发了周期重置(pre=0、post=amount),
|
||||||
|
// 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。
|
||||||
|
crossedTotal := state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit
|
||||||
|
crossedDaily := state.DailyLimit > 0 && state.DailyUsed >= state.DailyLimit && (state.DailyUsed-amount) < state.DailyLimit
|
||||||
|
crossedWeekly := state.WeeklyLimit > 0 && state.WeeklyUsed >= state.WeeklyLimit && (state.WeeklyUsed-amount) < state.WeeklyLimit
|
||||||
|
if crossedTotal || crossedDaily || crossedWeekly {
|
||||||
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||||
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
|||||||
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_EnqueuesSchedulerOutboxOnQuotaCrossing(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
|
||||||
|
newFixture := func(t *testing.T, extra map[string]any) (int64, int64) {
|
||||||
|
t.Helper()
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-outbox-user-%d-%s@example.com", time.Now().UnixNano(), uuid.NewString()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-usage-billing-outbox-" + uuid.NewString(),
|
||||||
|
Name: "billing-outbox",
|
||||||
|
})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{
|
||||||
|
Name: "usage-billing-outbox-" + uuid.NewString(),
|
||||||
|
Type: service.AccountTypeAPIKey,
|
||||||
|
Extra: extra,
|
||||||
|
})
|
||||||
|
return apiKey.ID, account.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
outboxCountFor := func(t *testing.T, accountID int64) int {
|
||||||
|
t.Helper()
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx,
|
||||||
|
"SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1 AND account_id = $2",
|
||||||
|
service.SchedulerOutboxEventAccountChanged, accountID,
|
||||||
|
).Scan(&count))
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("daily_first_crossing_enqueues", func(t *testing.T) {
|
||||||
|
apiKeyID, accountID := newFixture(t, map[string]any{
|
||||||
|
"quota_daily_limit": 10.0,
|
||||||
|
})
|
||||||
|
// 第一次低于日限额:不应入队 outbox
|
||||||
|
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
AccountType: service.AccountTypeAPIKey,
|
||||||
|
AccountQuotaCost: 4,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 0, outboxCountFor(t, accountID), "below limit should not enqueue")
|
||||||
|
|
||||||
|
// 第二次跨越日限额:应入队一次 outbox
|
||||||
|
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
AccountType: service.AccountTypeAPIKey,
|
||||||
|
AccountQuotaCost: 8,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, outboxCountFor(t, accountID), "crossing daily limit should enqueue once")
|
||||||
|
|
||||||
|
// 再次递增(已超):不应重复入队
|
||||||
|
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
AccountType: service.AccountTypeAPIKey,
|
||||||
|
AccountQuotaCost: 2,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, outboxCountFor(t, accountID), "subsequent increments beyond limit should not re-enqueue")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("weekly_first_crossing_enqueues", func(t *testing.T) {
|
||||||
|
apiKeyID, accountID := newFixture(t, map[string]any{
|
||||||
|
"quota_weekly_limit": 10.0,
|
||||||
|
})
|
||||||
|
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
AccountType: service.AccountTypeAPIKey,
|
||||||
|
AccountQuotaCost: 15, // 单次即跨越
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, outboxCountFor(t, accountID), "single-shot crossing weekly limit should enqueue once")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||||
|
|||||||
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
|
|||||||
sql sqlExecutor
|
sql sqlExecutor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
|
// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
|
||||||
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
||||||
return &userGroupRateRepository{sql: sqlDB}
|
return &userGroupRateRepository{sql: sqlDB}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByUserID 获取用户的所有专属分组倍率
|
// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||||
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
|
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
|
||||||
rows, err := r.sql.QueryContext(ctx, query, userID)
|
rows, err := r.sql.QueryContext(ctx, query, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByUserIDs 批量获取多个用户的专属分组倍率。
|
// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||||
// 返回结构:map[userID]map[groupID]rate
|
|
||||||
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
|
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
|
||||||
result := make(map[int64]map[int64]float64, len(userIDs))
|
result := make(map[int64]map[int64]float64, len(userIDs))
|
||||||
if len(userIDs) == 0 {
|
if len(userIDs) == 0 {
|
||||||
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
|||||||
rows, err := r.sql.QueryContext(ctx, `
|
rows, err := r.sql.QueryContext(ctx, `
|
||||||
SELECT user_id, group_id, rate_multiplier
|
SELECT user_id, group_id, rate_multiplier
|
||||||
FROM user_group_rate_multipliers
|
FROM user_group_rate_multipliers
|
||||||
WHERE user_id = ANY($1)
|
WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
|
||||||
`, pq.Array(uniqueIDs))
|
`, pq.Array(uniqueIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
|
||||||
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
|
||||||
FROM user_group_rate_multipliers ugr
|
FROM user_group_rate_multipliers ugr
|
||||||
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
||||||
WHERE ugr.group_id = $1
|
WHERE ugr.group_id = $1
|
||||||
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
|||||||
var result []service.UserGroupRateEntry
|
var result []service.UserGroupRateEntry
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var entry service.UserGroupRateEntry
|
var entry service.UserGroupRateEntry
|
||||||
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
|
var rate sql.NullFloat64
|
||||||
|
var rpm sql.NullInt32
|
||||||
|
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if rate.Valid {
|
||||||
|
v := rate.Float64
|
||||||
|
entry.RateMultiplier = &v
|
||||||
|
}
|
||||||
|
if rpm.Valid {
|
||||||
|
v := int(rpm.Int32)
|
||||||
|
entry.RPMOverride = &v
|
||||||
|
}
|
||||||
result = append(result, entry)
|
result = append(result, entry)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
|
||||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||||
var rate float64
|
var rate sql.NullFloat64
|
||||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &rate, nil
|
if !rate.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
v := rate.Float64
|
||||||
|
return &v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
|
||||||
|
func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
|
||||||
|
query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||||
|
var rpm sql.NullInt32
|
||||||
|
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !rpm.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
v := int(rpm.Int32)
|
||||||
|
return &v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
|
||||||
|
// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
|
||||||
|
// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
|
||||||
|
// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
|
||||||
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
||||||
if len(rates) == 0 {
|
if len(rates) == 0 {
|
||||||
// 如果传入空 map,删除该用户的所有专属倍率
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
UPDATE user_group_rate_multipliers
|
||||||
|
SET rate_multiplier = NULL, updated_at = NOW()
|
||||||
|
WHERE user_id = $1
|
||||||
|
`, userID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := r.sql.ExecContext(ctx,
|
||||||
|
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
|
||||||
|
userID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 分离需要删除和需要 upsert 的记录
|
var clearGroupIDs []int64
|
||||||
var toDelete []int64
|
|
||||||
upsertGroupIDs := make([]int64, 0, len(rates))
|
upsertGroupIDs := make([]int64, 0, len(rates))
|
||||||
upsertRates := make([]float64, 0, len(rates))
|
upsertRates := make([]float64, 0, len(rates))
|
||||||
for groupID, rate := range rates {
|
for groupID, rate := range rates {
|
||||||
if rate == nil {
|
if rate == nil {
|
||||||
toDelete = append(toDelete, groupID)
|
clearGroupIDs = append(clearGroupIDs, groupID)
|
||||||
} else {
|
} else {
|
||||||
upsertGroupIDs = append(upsertGroupIDs, groupID)
|
upsertGroupIDs = append(upsertGroupIDs, groupID)
|
||||||
upsertRates = append(upsertRates, *rate)
|
upsertRates = append(upsertRates, *rate)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 删除指定的记录
|
if len(clearGroupIDs) > 0 {
|
||||||
if len(toDelete) > 0 {
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
|
UPDATE user_group_rate_multipliers
|
||||||
|
SET rate_multiplier = NULL, updated_at = NOW()
|
||||||
|
WHERE user_id = $1 AND group_id = ANY($2)
|
||||||
|
`, userID, pq.Array(clearGroupIDs)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if _, err := r.sql.ExecContext(ctx,
|
if _, err := r.sql.ExecContext(ctx,
|
||||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
|
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
|
||||||
userID, pq.Array(toDelete)); err != nil {
|
userID, pq.Array(clearGroupIDs)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upsert 记录
|
|
||||||
now := time.Now()
|
|
||||||
if len(upsertGroupIDs) > 0 {
|
if len(upsertGroupIDs) > 0 {
|
||||||
|
now := time.Now()
|
||||||
_, err := r.sql.ExecContext(ctx, `
|
_, err := r.sql.ExecContext(ctx, `
|
||||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||||
SELECT
|
SELECT
|
||||||
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
|
// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
|
||||||
|
// 语义:
|
||||||
|
// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
|
||||||
|
// - 出现的用户行:upsert rate_multiplier。
|
||||||
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
|
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
|
||||||
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
|
keepUserIDs := make([]int64, 0, len(entries))
|
||||||
|
for _, e := range entries {
|
||||||
|
keepUserIDs = append(keepUserIDs, e.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 未在 entries 列表中的行:清空 rate_multiplier。
|
||||||
|
if len(keepUserIDs) == 0 {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
|
UPDATE user_group_rate_multipliers
|
||||||
|
SET rate_multiplier = NULL, updated_at = NOW()
|
||||||
|
WHERE group_id = $1
|
||||||
|
`, groupID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
|
UPDATE user_group_rate_multipliers
|
||||||
|
SET rate_multiplier = NULL, updated_at = NOW()
|
||||||
|
WHERE group_id = $1 AND user_id <> ALL($2)
|
||||||
|
`, groupID, pq.Array(keepUserIDs)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清空后若整行 NULL 则删除。
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
|
DELETE FROM user_group_rate_multipliers
|
||||||
|
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||||
|
`, groupID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(entries) == 0 {
|
if len(entries) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
userIDs := make([]int64, len(entries))
|
userIDs := make([]int64, len(entries))
|
||||||
rates := make([]float64, len(entries))
|
rates := make([]float64, len(entries))
|
||||||
for i, e := range entries {
|
for i, e := range entries {
|
||||||
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
|
||||||
|
// 语义:
|
||||||
|
// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
|
||||||
|
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
|
||||||
|
func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
|
||||||
|
keepUserIDs := make([]int64, 0, len(entries))
|
||||||
|
var clearUserIDs []int64
|
||||||
|
upsertUserIDs := make([]int64, 0, len(entries))
|
||||||
|
upsertValues := make([]int32, 0, len(entries))
|
||||||
|
for _, e := range entries {
|
||||||
|
keepUserIDs = append(keepUserIDs, e.UserID)
|
||||||
|
if e.RPMOverride == nil {
|
||||||
|
clearUserIDs = append(clearUserIDs, e.UserID)
|
||||||
|
} else {
|
||||||
|
upsertUserIDs = append(upsertUserIDs, e.UserID)
|
||||||
|
upsertValues = append(upsertValues, int32(*e.RPMOverride))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 未在 entries 列表中的行:清空 rpm_override。
|
||||||
|
if len(keepUserIDs) == 0 {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
|
UPDATE user_group_rate_multipliers
|
||||||
|
SET rpm_override = NULL, updated_at = NOW()
|
||||||
|
WHERE group_id = $1
|
||||||
|
`, groupID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
|
UPDATE user_group_rate_multipliers
|
||||||
|
SET rpm_override = NULL, updated_at = NOW()
|
||||||
|
WHERE group_id = $1 AND user_id <> ALL($2)
|
||||||
|
`, groupID, pq.Array(keepUserIDs)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 显式 clear 的行。
|
||||||
|
if len(clearUserIDs) > 0 {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
|
UPDATE user_group_rate_multipliers
|
||||||
|
SET rpm_override = NULL, updated_at = NOW()
|
||||||
|
WHERE group_id = $1 AND user_id = ANY($2)
|
||||||
|
`, groupID, pq.Array(clearUserIDs)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清空后若整行 NULL 则删除。
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
|
DELETE FROM user_group_rate_multipliers
|
||||||
|
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||||
|
`, groupID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(upsertUserIDs) > 0 {
|
||||||
|
now := time.Now()
|
||||||
|
_, err := r.sql.ExecContext(ctx, `
|
||||||
|
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
|
||||||
|
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
|
||||||
|
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
|
||||||
|
ON CONFLICT (user_id, group_id)
|
||||||
|
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
|
||||||
|
`, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
|
||||||
|
func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `
|
||||||
|
UPDATE user_group_rate_multipliers
|
||||||
|
SET rpm_override = NULL, updated_at = NOW()
|
||||||
|
WHERE group_id = $1
|
||||||
|
`, groupID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := r.sql.ExecContext(ctx, `
|
||||||
|
DELETE FROM user_group_rate_multipliers
|
||||||
|
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||||
|
`, groupID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByGroupID 删除指定分组的所有用户专属条目
|
||||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteByUserID 删除指定用户的所有专属倍率
|
// DeleteByUserID 删除指定用户的所有专属条目
|
||||||
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
||||||
SetNillableLastLoginAt(userIn.LastLoginAt).
|
SetNillableLastLoginAt(userIn.LastLoginAt).
|
||||||
SetNillableLastActiveAt(userIn.LastActiveAt).
|
SetNillableLastActiveAt(userIn.LastActiveAt).
|
||||||
|
SetRpmLimit(userIn.RPMLimit).
|
||||||
Save(txCtx)
|
Save(txCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||||
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
|
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
|
||||||
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
|
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
|
||||||
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
|
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
|
||||||
SetTotalRecharged(userIn.TotalRecharged)
|
SetTotalRecharged(userIn.TotalRecharged).
|
||||||
|
SetRpmLimit(userIn.RPMLimit)
|
||||||
if userIn.SignupSource != "" {
|
if userIn.SignupSource != "" {
|
||||||
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
|
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
|
||||||
}
|
}
|
||||||
|
|||||||
108
backend/internal/repository/user_rpm_cache.go
Normal file
108
backend/internal/repository/user_rpm_cache.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 用户/分组级 RPM 计数器 Redis 实现。
|
||||||
|
//
|
||||||
|
// 设计说明:
|
||||||
|
// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
|
||||||
|
// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
|
||||||
|
// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
|
||||||
|
// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
|
||||||
|
// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
|
||||||
|
const (
|
||||||
|
userGroupRPMKeyPrefix = "rpm:ug:"
|
||||||
|
userRPMKeyPrefix = "rpm:u:"
|
||||||
|
|
||||||
|
userRPMKeyTTL = 120 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type userRPMCacheImpl struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
|
||||||
|
func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
|
||||||
|
return &userRPMCacheImpl{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
// minuteTS 获取当前 Redis 服务端分钟时间戳。
|
||||||
|
func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
|
||||||
|
t, err := c.rdb.Time(ctx).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("redis TIME: %w", err)
|
||||||
|
}
|
||||||
|
return t.Unix() / 60, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// atomicIncr 原子 INCR+EXPIRE。
|
||||||
|
func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
|
||||||
|
pipe := c.rdb.TxPipeline()
|
||||||
|
incr := pipe.Incr(ctx, key)
|
||||||
|
pipe.Expire(ctx, key, userRPMKeyTTL)
|
||||||
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
|
return 0, fmt.Errorf("user rpm increment: %w", err)
|
||||||
|
}
|
||||||
|
return int(incr.Val()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
|
||||||
|
func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
|
||||||
|
minute, err := c.minuteTS(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
|
||||||
|
return c.atomicIncr(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementUserRPM 递增用户分钟计数。
|
||||||
|
func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
|
||||||
|
minute, err := c.minuteTS(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
|
||||||
|
return c.atomicIncr(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
|
||||||
|
func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
|
||||||
|
minute, err := c.minuteTS(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
|
||||||
|
val, err := c.rdb.Get(ctx, key).Int()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("user group rpm get: %w", err)
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
|
||||||
|
func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
|
||||||
|
minute, err := c.minuteTS(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
|
||||||
|
val, err := c.rdb.Get(ctx, key).Int()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("user rpm get: %w", err)
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
@@ -98,10 +98,12 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAPIKeyCache,
|
NewAPIKeyCache,
|
||||||
NewTempUnschedCache,
|
NewTempUnschedCache,
|
||||||
NewTimeoutCounterCache,
|
NewTimeoutCounterCache,
|
||||||
|
NewOpenAI403CounterCache,
|
||||||
NewInternal500CounterCache,
|
NewInternal500CounterCache,
|
||||||
ProvideConcurrencyCache,
|
ProvideConcurrencyCache,
|
||||||
ProvideSessionLimitCache,
|
ProvideSessionLimitCache,
|
||||||
NewRPMCache,
|
NewRPMCache,
|
||||||
|
NewUserRPMCache,
|
||||||
NewUserMsgQueueCache,
|
NewUserMsgQueueCache,
|
||||||
NewDashboardCache,
|
NewDashboardCache,
|
||||||
NewEmailCache,
|
NewEmailCache,
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"role": "user",
|
"role": "user",
|
||||||
"balance": 12.5,
|
"balance": 12.5,
|
||||||
"concurrency": 5,
|
"concurrency": 5,
|
||||||
|
"rpm_limit": 0,
|
||||||
"status": "active",
|
"status": "active",
|
||||||
"allowed_groups": null,
|
"allowed_groups": null,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
@@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"fallback_group_id_on_invalid_request": null,
|
"fallback_group_id_on_invalid_request": null,
|
||||||
"require_oauth_only": false,
|
"require_oauth_only": false,
|
||||||
"require_privacy_set": false,
|
"require_privacy_set": false,
|
||||||
|
"rpm_limit": 0,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
"updated_at": "2025-01-02T03:04:05Z"
|
"updated_at": "2025-01-02T03:04:05Z"
|
||||||
}
|
}
|
||||||
@@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"force_email_on_third_party_signup": false,
|
"force_email_on_third_party_signup": false,
|
||||||
"default_concurrency": 5,
|
"default_concurrency": 5,
|
||||||
"default_balance": 1.25,
|
"default_balance": 1.25,
|
||||||
|
"default_user_rpm_limit": 0,
|
||||||
"default_subscriptions": [],
|
"default_subscriptions": [],
|
||||||
"enable_model_fallback": false,
|
"enable_model_fallback": false,
|
||||||
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
||||||
@@ -892,6 +895,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"custom_endpoints": [],
|
"custom_endpoints": [],
|
||||||
"default_concurrency": 0,
|
"default_concurrency": 0,
|
||||||
"default_balance": 0,
|
"default_balance": 0,
|
||||||
|
"default_user_rpm_limit": 0,
|
||||||
"default_subscriptions": [],
|
"default_subscriptions": [],
|
||||||
"enable_model_fallback": false,
|
"enable_model_fallback": false,
|
||||||
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
||||||
@@ -1090,7 +1094,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
|
|||||||
@@ -224,6 +224,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||||
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
||||||
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
|
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
|
||||||
|
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
|
||||||
|
|
||||||
// User attribute values
|
// User attribute values
|
||||||
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
||||||
@@ -247,6 +248,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
||||||
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
|
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
|
||||||
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
|
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
|
||||||
|
groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides)
|
||||||
|
groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides)
|
||||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -930,10 +930,8 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
switch capability {
|
switch capability {
|
||||||
case OpenAIImagesCapabilityBasic:
|
case OpenAIImagesCapabilityBasic, OpenAIImagesCapabilityNative:
|
||||||
return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
|
return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
|
||||||
case OpenAIImagesCapabilityNative:
|
|
||||||
return a.Type == AccountTypeAPIKey
|
|
||||||
default:
|
default:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -1138,7 +1137,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via ChatGPT backend API.
|
// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API.
|
||||||
func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
|
func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
|
||||||
authToken := account.GetOpenAIAccessToken()
|
authToken := account.GetOpenAIAccessToken()
|
||||||
if authToken == "" {
|
if authToken == "" {
|
||||||
@@ -1153,69 +1152,46 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
|
|||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
|
|
||||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
|
s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Initializing ChatGPT backend...\n"})
|
s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"})
|
||||||
|
|
||||||
// Build headers (replicating buildOpenAIBackendAPIHeaders logic)
|
parsed := &OpenAIImagesRequest{
|
||||||
headers := buildOpenAIBackendAPIHeadersForTest(ctx, account, authToken, s.accountRepo)
|
Endpoint: openAIImagesGenerationsEndpoint,
|
||||||
|
Model: strings.TrimSpace(modelID),
|
||||||
|
Prompt: prompt,
|
||||||
|
}
|
||||||
|
applyOpenAIImagesDefaults(parsed)
|
||||||
|
|
||||||
|
responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, parsed.Model)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build image request: %s", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexAPIURL, bytes.NewReader(responsesBody))
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||||
|
}
|
||||||
|
req.Host = "chatgpt.com"
|
||||||
|
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||||
|
req.Header.Set("originator", "opencode")
|
||||||
|
if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
|
||||||
|
req.Header.Set("User-Agent", customUA)
|
||||||
|
} else {
|
||||||
|
req.Header.Set("User-Agent", codexCLIUserAgent)
|
||||||
|
}
|
||||||
|
if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
|
||||||
|
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||||
|
}
|
||||||
|
|
||||||
proxyURL := ""
|
proxyURL := ""
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||||
client, err := newOpenAIBackendAPIClient(proxyURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create client: %s", err.Error()))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Responses API request failed: %s", err.Error()))
|
||||||
}
|
|
||||||
|
|
||||||
// Bootstrap
|
|
||||||
if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil {
|
|
||||||
log.Printf("OpenAI image test bootstrap warning: %v", bootstrapErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch chat requirements
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Fetching chat requirements...\n"})
|
|
||||||
chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers)
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat requirements failed: %s", err.Error()))
|
|
||||||
}
|
|
||||||
if chatReqs.Arkose.Required {
|
|
||||||
return s.sendErrorAndEnd(c, "Unsupported challenge: arkose required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize and prepare conversation
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Preparing image conversation...\n"})
|
|
||||||
parentMessageID := uuid.NewString()
|
|
||||||
proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent"))
|
|
||||||
_ = initializeOpenAIImageConversation(ctx, client, headers)
|
|
||||||
conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, prompt, parentMessageID, chatReqs.Token, proofToken)
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation prepare failed: %s", err.Error()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build simplified conversation request (no file uploads)
|
|
||||||
convReq := buildOpenAIImageTestConversationRequest(prompt, parentMessageID)
|
|
||||||
convHeaders := cloneHTTPHeader(headers)
|
|
||||||
convHeaders.Set("Accept", "text/event-stream")
|
|
||||||
convHeaders.Set("Content-Type", "application/json")
|
|
||||||
convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token)
|
|
||||||
if conduitToken != "" {
|
|
||||||
convHeaders.Set("x-conduit-token", conduitToken)
|
|
||||||
}
|
|
||||||
if proofToken != "" {
|
|
||||||
convHeaders.Set("openai-sentinel-proof-token", proofToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Generating image...\n"})
|
|
||||||
|
|
||||||
resp, err := client.R().
|
|
||||||
SetContext(ctx).
|
|
||||||
DisableAutoReadResponse().
|
|
||||||
SetHeaders(headerToMap(convHeaders)).
|
|
||||||
SetBodyJsonMarshal(convReq).
|
|
||||||
Post(openAIChatGPTConversationURL)
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation request failed: %s", err.Error()))
|
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if resp != nil && resp.Body != nil {
|
if resp != nil && resp.Body != nil {
|
||||||
@@ -1223,49 +1199,35 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation API returned %d", resp.StatusCode))
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
message := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
|
if message == "" {
|
||||||
|
message = fmt.Sprintf("Responses API returned %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return s.sendErrorAndEnd(c, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
startTime := time.Now()
|
body, err := io.ReadAll(resp.Body)
|
||||||
conversationID, pointerInfos, _, _, err := readOpenAIImageConversationStream(resp, startTime)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read failed: %s", err.Error()))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read image response: %s", err.Error()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
|
results, _, _, _, _, err := collectOpenAIImagesFromResponsesBody(body)
|
||||||
if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
|
if err != nil {
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Waiting for image generation to complete...\n"})
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse image response: %s", err.Error()))
|
||||||
polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
|
|
||||||
if pollErr != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Poll failed: %s", pollErr.Error()))
|
|
||||||
}
|
|
||||||
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers)
|
|
||||||
}
|
}
|
||||||
pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
|
if len(results) == 0 {
|
||||||
if len(pointerInfos) == 0 {
|
return s.sendErrorAndEnd(c, "No images returned from responses API")
|
||||||
return s.sendErrorAndEnd(c, "No images returned from conversation")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Downloading generated image...\n"})
|
for _, item := range results {
|
||||||
|
if item.RevisedPrompt != "" {
|
||||||
// Download and encode each image
|
s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
|
||||||
for _, pointer := range pointerInfos {
|
|
||||||
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Download URL fetch failed: %s", err.Error()))
|
|
||||||
}
|
|
||||||
data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Image download failed: %s", err.Error()))
|
|
||||||
}
|
|
||||||
b64 := base64.StdEncoding.EncodeToString(data)
|
|
||||||
mimeType := http.DetectContentType(data)
|
|
||||||
if pointer.Prompt != "" {
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: pointer.Prompt})
|
|
||||||
}
|
}
|
||||||
|
mimeType := openAIImageOutputMIMEType(item.OutputFormat)
|
||||||
s.sendEvent(c, TestEvent{
|
s.sendEvent(c, TestEvent{
|
||||||
Type: "image",
|
Type: "image",
|
||||||
ImageURL: "data:" + mimeType + ";base64," + b64,
|
ImageURL: "data:" + mimeType + ";base64," + item.Result,
|
||||||
MimeType: mimeType,
|
MimeType: mimeType,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1274,107 +1236,6 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes.
|
|
||||||
// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without
|
|
||||||
// requiring the full gateway service dependency.
|
|
||||||
func buildOpenAIBackendAPIHeadersForTest(ctx context.Context, account *Account, token string, repo AccountRepository) http.Header {
|
|
||||||
// Ensure device and session IDs exist
|
|
||||||
deviceID := account.GetOpenAIDeviceID()
|
|
||||||
sessionID := account.GetOpenAISessionID()
|
|
||||||
if deviceID == "" || sessionID == "" {
|
|
||||||
updates := map[string]any{}
|
|
||||||
if deviceID == "" {
|
|
||||||
deviceID = uuid.NewString()
|
|
||||||
updates["openai_device_id"] = deviceID
|
|
||||||
}
|
|
||||||
if sessionID == "" {
|
|
||||||
sessionID = uuid.NewString()
|
|
||||||
updates["openai_session_id"] = sessionID
|
|
||||||
}
|
|
||||||
if account.Extra == nil {
|
|
||||||
account.Extra = map[string]any{}
|
|
||||||
}
|
|
||||||
for key, value := range updates {
|
|
||||||
account.Extra[key] = value
|
|
||||||
}
|
|
||||||
if repo != nil {
|
|
||||||
updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
_ = repo.UpdateExtra(updateCtx, account.ID, updates)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
headers := make(http.Header)
|
|
||||||
headers.Set("Authorization", "Bearer "+token)
|
|
||||||
headers.Set("Accept", "application/json")
|
|
||||||
headers.Set("Origin", "https://chatgpt.com")
|
|
||||||
headers.Set("Referer", "https://chatgpt.com/")
|
|
||||||
headers.Set("Sec-Fetch-Dest", "empty")
|
|
||||||
headers.Set("Sec-Fetch-Mode", "cors")
|
|
||||||
headers.Set("Sec-Fetch-Site", "same-origin")
|
|
||||||
headers.Set("User-Agent", openAIImageBackendUserAgent)
|
|
||||||
if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
|
|
||||||
headers.Set("User-Agent", customUA)
|
|
||||||
}
|
|
||||||
if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
|
|
||||||
headers.Set("chatgpt-account-id", chatgptAccountID)
|
|
||||||
}
|
|
||||||
if deviceID != "" {
|
|
||||||
headers.Set("oai-device-id", deviceID)
|
|
||||||
headers.Set("Cookie", "oai-did="+deviceID)
|
|
||||||
}
|
|
||||||
if sessionID != "" {
|
|
||||||
headers.Set("oai-session-id", sessionID)
|
|
||||||
}
|
|
||||||
return headers
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request.
|
|
||||||
func buildOpenAIImageTestConversationRequest(prompt, parentMessageID string) map[string]any {
|
|
||||||
promptText := strings.TrimSpace(prompt)
|
|
||||||
if promptText == "" {
|
|
||||||
promptText = "Generate an image."
|
|
||||||
}
|
|
||||||
metadata := map[string]any{
|
|
||||||
"developer_mode_connector_ids": []any{},
|
|
||||||
"selected_github_repos": []any{},
|
|
||||||
"selected_all_github_repos": false,
|
|
||||||
"system_hints": []string{"picture_v2"},
|
|
||||||
"serialization_metadata": map[string]any{
|
|
||||||
"custom_symbol_offsets": []any{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
message := map[string]any{
|
|
||||||
"id": uuid.NewString(),
|
|
||||||
"author": map[string]any{"role": "user"},
|
|
||||||
"content": map[string]any{
|
|
||||||
"content_type": "text",
|
|
||||||
"parts": []any{promptText},
|
|
||||||
},
|
|
||||||
"metadata": metadata,
|
|
||||||
"create_time": float64(time.Now().UnixMilli()) / 1000,
|
|
||||||
}
|
|
||||||
return map[string]any{
|
|
||||||
"action": "next",
|
|
||||||
"client_prepare_state": "sent",
|
|
||||||
"parent_message_id": parentMessageID,
|
|
||||||
"messages": []any{message},
|
|
||||||
"model": "auto",
|
|
||||||
"timezone_offset_min": openAITimezoneOffsetMinutes(),
|
|
||||||
"timezone": openAITimezoneName(),
|
|
||||||
"conversation_mode": map[string]any{"kind": "primary_assistant"},
|
|
||||||
"system_hints": []string{"picture_v2"},
|
|
||||||
"supports_buffering": true,
|
|
||||||
"supported_encodings": []string{"v1"},
|
|
||||||
"client_contextual_info": map[string]any{"app_name": "chatgpt.com"},
|
|
||||||
"force_nulligen": false,
|
|
||||||
"force_paragen": false,
|
|
||||||
"force_paragen_model_slug": "",
|
|
||||||
"force_rate_limit": false,
|
|
||||||
"websocket_request_id": uuid.NewString(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||||
eventJSON, _ := json.Marshal(event)
|
eventJSON, _ := json.Marshal(event)
|
||||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" +
|
||||||
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000006,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 53,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool")
|
||||||
|
require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
|
||||||
|
require.Contains(t, rec.Body.String(), "\"success\":true")
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,6 +33,7 @@ type AdminService interface {
|
|||||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
|
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
|
||||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||||
|
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
|
||||||
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||||
// codeType is optional - pass empty string to return all types.
|
// codeType is optional - pass empty string to return all types.
|
||||||
// Also returns totalRecharged (sum of all positive balance top-ups).
|
// Also returns totalRecharged (sum of all positive balance top-ups).
|
||||||
@@ -50,6 +52,8 @@ type AdminService interface {
|
|||||||
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||||
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
|
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
|
||||||
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||||
|
ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
|
||||||
|
BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
|
||||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||||
|
|
||||||
// API Key management (admin)
|
// API Key management (admin)
|
||||||
@@ -114,6 +118,7 @@ type CreateUserInput struct {
|
|||||||
Notes string
|
Notes string
|
||||||
Balance float64
|
Balance float64
|
||||||
Concurrency int
|
Concurrency int
|
||||||
|
RPMLimit int
|
||||||
AllowedGroups []int64
|
AllowedGroups []int64
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,6 +129,7 @@ type UpdateUserInput struct {
|
|||||||
Notes *string
|
Notes *string
|
||||||
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
||||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||||
|
RPMLimit *int // 使用指针区分"未提供"和"设置为0"
|
||||||
Status string
|
Status string
|
||||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||||
// GroupRates 用户专属分组倍率配置
|
// GroupRates 用户专属分组倍率配置
|
||||||
@@ -199,6 +205,8 @@ type CreateGroupInput struct {
|
|||||||
RequireOAuthOnly bool
|
RequireOAuthOnly bool
|
||||||
RequirePrivacySet bool
|
RequirePrivacySet bool
|
||||||
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
|
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
|
||||||
|
// RPMLimit 分组 RPM 上限(0 = 不限制)
|
||||||
|
RPMLimit int
|
||||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||||
CopyAccountsFromGroupIDs []int64
|
CopyAccountsFromGroupIDs []int64
|
||||||
}
|
}
|
||||||
@@ -234,6 +242,8 @@ type UpdateGroupInput struct {
|
|||||||
RequireOAuthOnly *bool
|
RequireOAuthOnly *bool
|
||||||
RequirePrivacySet *bool
|
RequirePrivacySet *bool
|
||||||
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
|
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
|
||||||
|
// RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
|
||||||
|
RPMLimit *int
|
||||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||||
CopyAccountsFromGroupIDs []int64
|
CopyAccountsFromGroupIDs []int64
|
||||||
}
|
}
|
||||||
@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct {
|
|||||||
MigratedKeys int64 // 迁移的 Key 数量
|
MigratedKeys int64 // 迁移的 Key 数量
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserRPMStatus describes a user's current per-minute RPM usage.
|
||||||
|
type UserRPMStatus struct {
|
||||||
|
UserRPMUsed int `json:"user_rpm_used"`
|
||||||
|
UserRPMLimit int `json:"user_rpm_limit"`
|
||||||
|
PerGroup []UserGroupRPMStatus `json:"per_group"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
|
||||||
|
type UserGroupRPMStatus struct {
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
GroupName string `json:"group_name"`
|
||||||
|
Used int `json:"used"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
Source string `json:"source"` // "group" | "override"
|
||||||
|
}
|
||||||
|
|
||||||
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||||
type BulkUpdateAccountsResult struct {
|
type BulkUpdateAccountsResult struct {
|
||||||
Success int `json:"success"`
|
Success int `json:"success"`
|
||||||
@@ -463,6 +489,8 @@ const (
|
|||||||
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
|
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available")
|
||||||
|
|
||||||
// adminServiceImpl implements AdminService
|
// adminServiceImpl implements AdminService
|
||||||
type adminServiceImpl struct {
|
type adminServiceImpl struct {
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
@@ -472,6 +500,7 @@ type adminServiceImpl struct {
|
|||||||
apiKeyRepo APIKeyRepository
|
apiKeyRepo APIKeyRepository
|
||||||
redeemCodeRepo RedeemCodeRepository
|
redeemCodeRepo RedeemCodeRepository
|
||||||
userGroupRateRepo UserGroupRateRepository
|
userGroupRateRepo UserGroupRateRepository
|
||||||
|
userRPMCache UserRPMCache
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
proxyProber ProxyExitInfoProber
|
proxyProber ProxyExitInfoProber
|
||||||
proxyLatencyCache ProxyLatencyCache
|
proxyLatencyCache ProxyLatencyCache
|
||||||
@@ -496,6 +525,7 @@ func NewAdminService(
|
|||||||
apiKeyRepo APIKeyRepository,
|
apiKeyRepo APIKeyRepository,
|
||||||
redeemCodeRepo RedeemCodeRepository,
|
redeemCodeRepo RedeemCodeRepository,
|
||||||
userGroupRateRepo UserGroupRateRepository,
|
userGroupRateRepo UserGroupRateRepository,
|
||||||
|
userRPMCache UserRPMCache,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
proxyProber ProxyExitInfoProber,
|
proxyProber ProxyExitInfoProber,
|
||||||
proxyLatencyCache ProxyLatencyCache,
|
proxyLatencyCache ProxyLatencyCache,
|
||||||
@@ -514,6 +544,7 @@ func NewAdminService(
|
|||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
redeemCodeRepo: redeemCodeRepo,
|
redeemCodeRepo: redeemCodeRepo,
|
||||||
userGroupRateRepo: userGroupRateRepo,
|
userGroupRateRepo: userGroupRateRepo,
|
||||||
|
userRPMCache: userRPMCache,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
proxyProber: proxyProber,
|
proxyProber: proxyProber,
|
||||||
proxyLatencyCache: proxyLatencyCache,
|
proxyLatencyCache: proxyLatencyCache,
|
||||||
@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
|||||||
Role: RoleUser, // Always create as regular user, never admin
|
Role: RoleUser, // Always create as regular user, never admin
|
||||||
Balance: input.Balance,
|
Balance: input.Balance,
|
||||||
Concurrency: input.Concurrency,
|
Concurrency: input.Concurrency,
|
||||||
|
RPMLimit: input.RPMLimit,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
AllowedGroups: input.AllowedGroups,
|
AllowedGroups: input.AllowedGroups,
|
||||||
}
|
}
|
||||||
@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
oldConcurrency := user.Concurrency
|
oldConcurrency := user.Concurrency
|
||||||
oldStatus := user.Status
|
oldStatus := user.Status
|
||||||
oldRole := user.Role
|
oldRole := user.Role
|
||||||
|
oldRPMLimit := user.RPMLimit
|
||||||
|
|
||||||
if input.Email != "" {
|
if input.Email != "" {
|
||||||
user.Email = input.Email
|
user.Email = input.Email
|
||||||
@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
user.Concurrency = *input.Concurrency
|
user.Concurrency = *input.Concurrency
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.RPMLimit != nil {
|
||||||
|
user.RPMLimit = *input.RPMLimit
|
||||||
|
}
|
||||||
|
|
||||||
if input.AllowedGroups != nil {
|
if input.AllowedGroups != nil {
|
||||||
user.AllowedGroups = *input.AllowedGroups
|
user.AllowedGroups = *input.AllowedGroups
|
||||||
}
|
}
|
||||||
@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.authCacheInvalidator != nil {
|
if s.authCacheInvalidator != nil {
|
||||||
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
// RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
|
||||||
|
// 不失效缓存会让修改在一个 L2 TTL 内失去效果。
|
||||||
|
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit {
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
|
|||||||
return keys, result.Total, nil
|
return keys, result.Total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) {
|
||||||
|
if s.userRPMCache == nil {
|
||||||
|
return nil, ErrRPMStatusUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := s.userRepo.GetByID(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
groupIDSet := make(map[int64]struct{})
|
||||||
|
for _, key := range keys {
|
||||||
|
if key.GroupID != nil && *key.GroupID > 0 {
|
||||||
|
groupIDSet[*key.GroupID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groupIDs := make([]int64, 0, len(groupIDSet))
|
||||||
|
for groupID := range groupIDSet {
|
||||||
|
groupIDs = append(groupIDs, groupID)
|
||||||
|
}
|
||||||
|
sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
|
||||||
|
|
||||||
|
var perGroup []UserGroupRPMStatus
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID)
|
||||||
|
if getErr != nil {
|
||||||
|
logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := UserGroupRPMStatus{
|
||||||
|
GroupID: groupID,
|
||||||
|
Used: used,
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.groupRepo != nil {
|
||||||
|
if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil {
|
||||||
|
entry.GroupName = group.Name
|
||||||
|
entry.Limit = group.RPMLimit
|
||||||
|
entry.Source = "group"
|
||||||
|
} else if groupErr != nil {
|
||||||
|
logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.userGroupRateRepo != nil {
|
||||||
|
override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID)
|
||||||
|
if overrideErr != nil {
|
||||||
|
logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr)
|
||||||
|
} else if override != nil {
|
||||||
|
entry.Limit = *override
|
||||||
|
entry.Source = "override"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
perGroup = append(perGroup, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserRPMStatus{
|
||||||
|
UserRPMUsed: userRPMUsed,
|
||||||
|
UserRPMLimit: user.RPMLimit,
|
||||||
|
PerGroup: perGroup,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||||
// Return mock data for now
|
// Return mock data for now
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
RequirePrivacySet: input.RequirePrivacySet,
|
RequirePrivacySet: input.RequirePrivacySet,
|
||||||
DefaultMappedModel: input.DefaultMappedModel,
|
DefaultMappedModel: input.DefaultMappedModel,
|
||||||
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
|
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
|
||||||
|
RPMLimit: input.RPMLimit,
|
||||||
}
|
}
|
||||||
sanitizeGroupMessagesDispatchFields(group)
|
sanitizeGroupMessagesDispatchFields(group)
|
||||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||||
@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if input.MessagesDispatchModelConfig != nil {
|
if input.MessagesDispatchModelConfig != nil {
|
||||||
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
|
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
|
||||||
}
|
}
|
||||||
|
if input.RPMLimit != nil {
|
||||||
|
group.RPMLimit = *input.RPMLimit
|
||||||
|
}
|
||||||
sanitizeGroupMessagesDispatchFields(group)
|
sanitizeGroupMessagesDispatchFields(group)
|
||||||
|
|
||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
|
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
|
||||||
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||||
// 去重源分组 IDs
|
// 去重源分组 IDs
|
||||||
@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.authCacheInvalidator != nil {
|
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
|
||||||
}
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
|
|||||||
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
|
||||||
|
if s.userGroupRateRepo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
|
||||||
|
if s.userGroupRateRepo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.RPMOverride != nil && *e.RPMOverride < 0 {
|
||||||
|
return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
|
|||||||
syncedGroupID int64
|
syncedGroupID int64
|
||||||
syncedEntries []GroupRateMultiplierInput
|
syncedEntries []GroupRateMultiplierInput
|
||||||
syncGroupErr error
|
syncGroupErr error
|
||||||
|
|
||||||
|
rpmSyncedGroupID int64
|
||||||
|
rpmSyncedEntries []GroupRPMOverrideInput
|
||||||
|
rpmSyncErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
|
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
|
||||||
@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
|
|||||||
panic("unexpected GetByUserAndGroup call")
|
panic("unexpected GetByUserAndGroup call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||||||
|
panic("unexpected GetRPMOverrideByUserAndGroup call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||||
if s.getByGroupIDErr != nil {
|
if s.getByGroupIDErr != nil {
|
||||||
return nil, s.getByGroupIDErr
|
return nil, s.getByGroupIDErr
|
||||||
@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
|
|||||||
return s.syncGroupErr
|
return s.syncGroupErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
|
||||||
|
s.rpmSyncedGroupID = groupID
|
||||||
|
s.rpmSyncedEntries = entries
|
||||||
|
return s.rpmSyncErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
|
||||||
|
panic("unexpected ClearGroupRPMOverrides call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
|
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||||
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
|
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
|
||||||
return s.deleteByGroupErr
|
return s.deleteByGroupErr
|
||||||
@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
|||||||
repo := &userGroupRateRepoStubForGroupRate{
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
getByGroupIDData: map[int64][]UserGroupRateEntry{
|
getByGroupIDData: map[int64][]UserGroupRateEntry{
|
||||||
10: {
|
10: {
|
||||||
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
|
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)},
|
||||||
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
|
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
|||||||
require.Len(t, entries, 2)
|
require.Len(t, entries, 2)
|
||||||
require.Equal(t, int64(1), entries[0].UserID)
|
require.Equal(t, int64(1), entries[0].UserID)
|
||||||
require.Equal(t, "alice", entries[0].UserName)
|
require.Equal(t, "alice", entries[0].UserName)
|
||||||
require.Equal(t, 1.5, entries[0].RateMultiplier)
|
require.NotNil(t, entries[0].RateMultiplier)
|
||||||
|
require.Equal(t, 1.5, *entries[0].RateMultiplier)
|
||||||
require.Equal(t, int64(2), entries[1].UserID)
|
require.Equal(t, int64(2), entries[1].UserID)
|
||||||
require.Equal(t, 0.8, entries[1].RateMultiplier)
|
require.NotNil(t, entries[1].RateMultiplier)
|
||||||
|
require.Equal(t, 0.8, *entries[1].RateMultiplier)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||||
@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
|
|||||||
require.Contains(t, err.Error(), "sync failed")
|
require.Contains(t, err.Error(), "sync failed")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) {
|
||||||
|
t.Run("syncs entries to repo", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
override := 20
|
||||||
|
entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}}
|
||||||
|
|
||||||
|
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(10), repo.rpmSyncedGroupID)
|
||||||
|
require.Equal(t, entries, repo.rpmSyncedEntries)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects negative override as bad request", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
negative := -1
|
||||||
|
|
||||||
|
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{
|
||||||
|
{UserID: 2, RPMOverride: &negative},
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, http.StatusBadRequest, infraerrors.Code(err))
|
||||||
|
require.Zero(t, repo.rpmSyncedGroupID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
|||||||
require.Nil(t, repo.updated.ImagePrice4K)
|
require.Nil(t, repo.updated.ImagePrice4K)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
|
||||||
|
existingGroup := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "existing-group",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
RPMLimit: 10,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
groupRepo: repo,
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
rpmLimit := 60
|
||||||
|
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||||
|
RPMLimit: &rpmLimit,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.Equal(t, 60, repo.updated.RPMLimit)
|
||||||
|
require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存")
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
|
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
|
||||||
repo := &groupRepoStubForAdmin{}
|
repo := &groupRepoStubForAdmin{}
|
||||||
svc := &adminServiceImpl{groupRepo: repo}
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|||||||
@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
|
|||||||
panic("unexpected GetByUserAndGroup call")
|
panic("unexpected GetByUserAndGroup call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||||||
|
panic("unexpected GetRPMOverrideByUserAndGroup call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
|
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
|
||||||
panic("unexpected SyncUserGroupRates call")
|
panic("unexpected SyncUserGroupRates call")
|
||||||
}
|
}
|
||||||
@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
|
|||||||
panic("unexpected SyncGroupRateMultipliers call")
|
panic("unexpected SyncGroupRateMultipliers call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error {
|
||||||
|
panic("unexpected SyncGroupRPMOverrides call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
|
||||||
|
panic("unexpected ClearGroupRPMOverrides call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
|
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
|
||||||
panic("unexpected DeleteByGroupID call")
|
panic("unexpected DeleteByGroupID call")
|
||||||
}
|
}
|
||||||
|
|||||||
112
backend/internal/service/admin_service_rpm_status_test.go
Normal file
112
backend/internal/service/admin_service_rpm_status_test.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rpmStatusUserRepoStub struct {
|
||||||
|
UserRepository
|
||||||
|
user *User
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
|
||||||
|
return s.user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpmStatusAPIKeyRepoStub struct {
|
||||||
|
APIKeyRepository
|
||||||
|
keys []APIKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpmStatusGroupRepoStub struct {
|
||||||
|
GroupRepository
|
||||||
|
groups map[int64]*Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||||
|
return s.groups[id], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpmStatusRateRepoStub struct {
|
||||||
|
UserGroupRateRepository
|
||||||
|
overrides map[int64]*int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) {
|
||||||
|
return s.overrides[groupID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpmStatusCacheStub struct {
|
||||||
|
UserRPMCache
|
||||||
|
userUsed int
|
||||||
|
groupUsed map[int64]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) {
|
||||||
|
return s.groupUsed[groupID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) {
|
||||||
|
return s.userUsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) {
|
||||||
|
groupOneID := int64(1)
|
||||||
|
groupTwoID := int64(2)
|
||||||
|
override := 7
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: &rpmStatusUserRepoStub{user: &User{
|
||||||
|
ID: 42,
|
||||||
|
RPMLimit: 20,
|
||||||
|
}},
|
||||||
|
apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{
|
||||||
|
{ID: 100, UserID: 42, GroupID: &groupTwoID},
|
||||||
|
{ID: 101, UserID: 42, GroupID: &groupOneID},
|
||||||
|
{ID: 102, UserID: 42, GroupID: &groupTwoID},
|
||||||
|
{ID: 103, UserID: 42},
|
||||||
|
}},
|
||||||
|
groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{
|
||||||
|
groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10},
|
||||||
|
groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60},
|
||||||
|
}},
|
||||||
|
userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{
|
||||||
|
groupTwoID: &override,
|
||||||
|
}},
|
||||||
|
userRPMCache: &rpmStatusCacheStub{
|
||||||
|
userUsed: 5,
|
||||||
|
groupUsed: map[int64]int{
|
||||||
|
groupOneID: 3,
|
||||||
|
groupTwoID: 4,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := svc.GetUserRPMStatus(context.Background(), 42)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, &UserRPMStatus{
|
||||||
|
UserRPMUsed: 5,
|
||||||
|
UserRPMLimit: 20,
|
||||||
|
PerGroup: []UserGroupRPMStatus{
|
||||||
|
{GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"},
|
||||||
|
{GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"},
|
||||||
|
},
|
||||||
|
}, status)
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
|
||||||
|
// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
|
||||||
|
type rpmUserRepoStub struct {
|
||||||
|
*userRepoStub
|
||||||
|
lastUpdated *User
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *user
|
||||||
|
s.lastUpdated = &clone
|
||||||
|
if s.userRepoStub != nil {
|
||||||
|
s.userRepoStub.user = &clone
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
|
||||||
|
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}}
|
||||||
|
repo := &rpmUserRepoStub{userRepoStub: base}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: repo,
|
||||||
|
redeemCodeRepo: &redeemRepoStub{},
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
newRPM := 60
|
||||||
|
updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
|
||||||
|
RPMLimit: &newRPM,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updated)
|
||||||
|
require.Equal(t, 60, updated.RPMLimit)
|
||||||
|
require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) {
|
||||||
|
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}}
|
||||||
|
repo := &rpmUserRepoStub{userRepoStub: base}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: repo,
|
||||||
|
redeemCodeRepo: &redeemRepoStub{},
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
newName := "new"
|
||||||
|
sameRPM := 10
|
||||||
|
_, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
|
||||||
|
Username: &newName,
|
||||||
|
RPMLimit: &sameRPM,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效")
|
||||||
|
}
|
||||||
@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
|
|||||||
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
|
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
|
||||||
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
|
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
|
||||||
TotalRecharged float64 `json:"total_recharged"`
|
TotalRecharged float64 `json:"total_recharged"`
|
||||||
|
|
||||||
|
// RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
|
||||||
|
RPMLimit int `json:"rpm_limit"`
|
||||||
|
|
||||||
|
// UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
|
||||||
|
// nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
|
||||||
|
UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyAuthGroupSnapshot 分组快照
|
// APIKeyAuthGroupSnapshot 分组快照
|
||||||
@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
|
|||||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||||
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
|
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
|
||||||
|
|
||||||
|
// RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
|
||||||
|
RPMLimit int `json:"rpm_limit"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/dgraph-io/ristretto"
|
"github.com/dgraph-io/ristretto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
|
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
|
||||||
|
|
||||||
type apiKeyAuthCacheConfig struct {
|
type apiKeyAuthCacheConfig struct {
|
||||||
l1Size int
|
l1Size int
|
||||||
@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
|
|||||||
return nil, fmt.Errorf("get api key: %w", err)
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
}
|
}
|
||||||
apiKey.Key = key
|
apiKey.Key = key
|
||||||
snapshot := s.snapshotFromAPIKey(apiKey)
|
snapshot := s.snapshotFromAPIKey(ctx, apiKey)
|
||||||
if snapshot == nil {
|
if snapshot == nil {
|
||||||
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
||||||
}
|
}
|
||||||
@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
|
|||||||
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||||
if apiKey == nil || apiKey.User == nil {
|
if apiKey == nil || apiKey.User == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
|
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
|
||||||
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
|
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
|
||||||
TotalRecharged: apiKey.User.TotalRecharged,
|
TotalRecharged: apiKey.User.TotalRecharged,
|
||||||
|
RPMLimit: apiKey.User.RPMLimit,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
|
||||||
|
if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
|
||||||
|
override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
|
||||||
|
if err == nil && override != nil {
|
||||||
|
snapshot.User.UserGroupRPMOverride = override
|
||||||
|
}
|
||||||
|
// 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
|
||||||
|
}
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||||
ID: apiKey.Group.ID,
|
ID: apiKey.Group.ID,
|
||||||
@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
||||||
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
||||||
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
|
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
|
||||||
|
RPMLimit: apiKey.Group.RPMLimit,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return snapshot
|
return snapshot
|
||||||
@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
|
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
|
||||||
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
|
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
|
||||||
TotalRecharged: snapshot.User.TotalRecharged,
|
TotalRecharged: snapshot.User.TotalRecharged,
|
||||||
|
RPMLimit: snapshot.User.RPMLimit,
|
||||||
|
UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if snapshot.Group != nil {
|
if snapshot.Group != nil {
|
||||||
@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
||||||
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
||||||
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
|
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
|
||||||
|
RPMLimit: snapshot.Group.RPMLimit,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.compileAPIKeyIPRules(apiKey)
|
s.compileAPIKeyIPRules(apiKey)
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
snapshot := svc.snapshotFromAPIKey(apiKey)
|
snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
|
||||||
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
|
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
|
||||||
|
|
||||||
require.NotNil(t, roundTrip)
|
require.NotNil(t, roundTrip)
|
||||||
|
|||||||
@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
|||||||
|
|
||||||
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
|
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
|
||||||
|
|
||||||
|
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
|
||||||
|
var defaultRPMLimit int
|
||||||
|
if s.settingService != nil {
|
||||||
|
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// 创建用户
|
// 创建用户
|
||||||
user := &User{
|
user := &User{
|
||||||
Email: email,
|
Email: email,
|
||||||
@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
|||||||
Role: RoleUser,
|
Role: RoleUser,
|
||||||
Balance: grantPlan.Balance,
|
Balance: grantPlan.Balance,
|
||||||
Concurrency: grantPlan.Concurrency,
|
Concurrency: grantPlan.Concurrency,
|
||||||
|
RPMLimit: defaultRPMLimit,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
|||||||
|
|
||||||
signupSource := inferLegacySignupSource(email)
|
signupSource := inferLegacySignupSource(email)
|
||||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||||
|
var defaultRPMLimit int
|
||||||
|
if s.settingService != nil {
|
||||||
|
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
newUser := &User{
|
newUser := &User{
|
||||||
Email: email,
|
Email: email,
|
||||||
@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
|||||||
Role: RoleUser,
|
Role: RoleUser,
|
||||||
Balance: grantPlan.Balance,
|
Balance: grantPlan.Balance,
|
||||||
Concurrency: grantPlan.Concurrency,
|
Concurrency: grantPlan.Concurrency,
|
||||||
|
RPMLimit: defaultRPMLimit,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
SignupSource: signupSource,
|
SignupSource: signupSource,
|
||||||
}
|
}
|
||||||
@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
|||||||
|
|
||||||
signupSource := inferLegacySignupSource(email)
|
signupSource := inferLegacySignupSource(email)
|
||||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||||
|
var defaultRPMLimit int
|
||||||
|
if s.settingService != nil {
|
||||||
|
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
newUser := &User{
|
newUser := &User{
|
||||||
Email: email,
|
Email: email,
|
||||||
@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
|||||||
Role: RoleUser,
|
Role: RoleUser,
|
||||||
Balance: grantPlan.Balance,
|
Balance: grantPlan.Balance,
|
||||||
Concurrency: grantPlan.Concurrency,
|
Concurrency: grantPlan.Concurrency,
|
||||||
|
RPMLimit: defaultRPMLimit,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
SignupSource: signupSource,
|
SignupSource: signupSource,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ import (
|
|||||||
var (
|
var (
|
||||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||||
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
||||||
|
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
|
||||||
|
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
|
||||||
|
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
|
||||||
)
|
)
|
||||||
|
|
||||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||||
@@ -87,6 +90,8 @@ type BillingCacheService struct {
|
|||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
subRepo UserSubscriptionRepository
|
subRepo UserSubscriptionRepository
|
||||||
apiKeyRateLimitLoader apiKeyRateLimitLoader
|
apiKeyRateLimitLoader apiKeyRateLimitLoader
|
||||||
|
userRPMCache UserRPMCache
|
||||||
|
userGroupRateRepo UserGroupRateRepository
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
circuitBreaker *billingCircuitBreaker
|
circuitBreaker *billingCircuitBreaker
|
||||||
|
|
||||||
@@ -104,12 +109,22 @@ type BillingCacheService struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewBillingCacheService 创建计费缓存服务
|
// NewBillingCacheService 创建计费缓存服务
|
||||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
|
func NewBillingCacheService(
|
||||||
|
cache BillingCache,
|
||||||
|
userRepo UserRepository,
|
||||||
|
subRepo UserSubscriptionRepository,
|
||||||
|
apiKeyRepo APIKeyRepository,
|
||||||
|
userRPMCache UserRPMCache,
|
||||||
|
userGroupRateRepo UserGroupRateRepository,
|
||||||
|
cfg *config.Config,
|
||||||
|
) *BillingCacheService {
|
||||||
svc := &BillingCacheService{
|
svc := &BillingCacheService{
|
||||||
cache: cache,
|
cache: cache,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
subRepo: subRepo,
|
subRepo: subRepo,
|
||||||
apiKeyRateLimitLoader: apiKeyRepo,
|
apiKeyRateLimitLoader: apiKeyRepo,
|
||||||
|
userRPMCache: userRPMCache,
|
||||||
|
userGroupRateRepo: userGroupRateRepo,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||||
@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。
|
||||||
|
if err := s.checkRPM(ctx, user, group); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝:
|
||||||
|
//
|
||||||
|
// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。
|
||||||
|
// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。
|
||||||
|
// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。
|
||||||
|
// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。
|
||||||
|
//
|
||||||
|
// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。
|
||||||
|
// Redis 故障一律 fail-open(打 warning,不阻塞业务)。
|
||||||
|
func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
|
||||||
|
if s == nil || s.userRPMCache == nil || user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 第一层:分组级检查(override 或 group.rpm_limit) ──
|
||||||
|
if group != nil {
|
||||||
|
// 解析 override:优先从 auth cache snapshot,nil 时回退 DB。
|
||||||
|
var override *int
|
||||||
|
if user.UserGroupRPMOverride != nil {
|
||||||
|
override = user.UserGroupRPMOverride
|
||||||
|
} else if s.userGroupRateRepo != nil {
|
||||||
|
dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.billing_cache",
|
||||||
|
"Warning: rpm override lookup failed for user=%d group=%d: %v",
|
||||||
|
user.ID, group.ID, err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
override = dbOverride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if override != nil {
|
||||||
|
// override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。
|
||||||
|
if *override > 0 {
|
||||||
|
count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
|
||||||
|
if incErr != nil {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.billing_cache",
|
||||||
|
"Warning: rpm increment (override) failed for user=%d group=%d: %v",
|
||||||
|
user.ID, group.ID, incErr,
|
||||||
|
)
|
||||||
|
// fail-open
|
||||||
|
} else if count > *override {
|
||||||
|
return ErrGroupRPMExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。
|
||||||
|
} else if group.RPMLimit > 0 {
|
||||||
|
// 无 override,检查 group.rpm_limit。
|
||||||
|
count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.billing_cache",
|
||||||
|
"Warning: rpm increment (group) failed for user=%d group=%d: %v",
|
||||||
|
user.ID, group.ID, err,
|
||||||
|
)
|
||||||
|
// fail-open
|
||||||
|
} else if count > group.RPMLimit {
|
||||||
|
return ErrGroupRPMExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 第二层:用户级全局硬上限(始终生效) ──
|
||||||
|
if user.RPMLimit > 0 {
|
||||||
|
count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.billing_cache",
|
||||||
|
"Warning: rpm increment (user) failed for user=%d: %v",
|
||||||
|
user.ID, err,
|
||||||
|
)
|
||||||
|
return nil // fail-open
|
||||||
|
}
|
||||||
|
if count > user.RPMLimit {
|
||||||
|
return ErrUserRPMExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
253
backend/internal/service/billing_cache_service_rpm_test.go
Normal file
253
backend/internal/service/billing_cache_service_rpm_test.go
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。
|
||||||
|
type userRPMCacheStub struct {
|
||||||
|
userGroupCalls int32
|
||||||
|
userCalls int32
|
||||||
|
|
||||||
|
userGroupCounts []int // 依次返回的计数值
|
||||||
|
userGroupErr error
|
||||||
|
userCounts []int
|
||||||
|
userErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
|
||||||
|
idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1
|
||||||
|
if s.userGroupErr != nil {
|
||||||
|
return 0, s.userGroupErr
|
||||||
|
}
|
||||||
|
if idx < len(s.userGroupCounts) {
|
||||||
|
return s.userGroupCounts[idx], nil
|
||||||
|
}
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
|
||||||
|
idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1
|
||||||
|
if s.userErr != nil {
|
||||||
|
return 0, s.userErr
|
||||||
|
}
|
||||||
|
if idx < len(s.userCounts) {
|
||||||
|
return s.userCounts[idx], nil
|
||||||
|
}
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。
|
||||||
|
type rpmOverrideRepoStub struct {
|
||||||
|
UserGroupRateRepository
|
||||||
|
|
||||||
|
override *int
|
||||||
|
err error
|
||||||
|
calls int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||||||
|
atomic.AddInt32(&s.calls, 1)
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return s.override, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
|
||||||
|
t.Helper()
|
||||||
|
// 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。
|
||||||
|
// 我们只直接测 checkRPM。
|
||||||
|
svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
|
||||||
|
t.Cleanup(svc.Stop)
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
|
||||||
|
override := 2
|
||||||
|
// user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰)
|
||||||
|
cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
|
||||||
|
repo := &rpmOverrideRepoStub{override: &override}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 100} // 全局上限设高,不干扰 override 测试
|
||||||
|
group := &Group{ID: 10, RPMLimit: 100}
|
||||||
|
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded)
|
||||||
|
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数")
|
||||||
|
// 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user
|
||||||
|
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用")
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
|
||||||
|
override := 100 // override 很高
|
||||||
|
// user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3
|
||||||
|
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
|
||||||
|
repo := &rpmOverrideRepoStub{override: &override}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 2} // 全局硬上限=2,应覆盖 override=100
|
||||||
|
group := &Group{ID: 10, RPMLimit: 100}
|
||||||
|
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
|
||||||
|
zero := 0
|
||||||
|
// user 计数: 依次返回 1..6
|
||||||
|
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
|
||||||
|
repo := &rpmOverrideRepoStub{override: &zero}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 5}
|
||||||
|
group := &Group{ID: 10, RPMLimit: 100}
|
||||||
|
|
||||||
|
// override=0 跳过分组计数,但 user.RPMLimit=5 仍生效
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1)
|
||||||
|
}
|
||||||
|
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded,
|
||||||
|
"override=0 跳过分组但 user 全局上限仍应生效")
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器")
|
||||||
|
require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
|
||||||
|
zero := 0
|
||||||
|
cache := &userRPMCacheStub{}
|
||||||
|
repo := &rpmOverrideRepoStub{override: &zero}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 0} // user 也不限
|
||||||
|
group := &Group{ID: 10, RPMLimit: 100}
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
}
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数")
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
|
||||||
|
// user-group 计数: 5, 6;user 计数: 默认 1(不干扰)
|
||||||
|
cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
|
||||||
|
repo := &rpmOverrideRepoStub{override: nil}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 999} // 全局上限很高,group 先超
|
||||||
|
group := &Group{ID: 10, RPMLimit: 5}
|
||||||
|
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, 都没超
|
||||||
|
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5
|
||||||
|
|
||||||
|
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls))
|
||||||
|
// 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user
|
||||||
|
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
|
||||||
|
cache := &userRPMCacheStub{userGroupCounts: []int{3}}
|
||||||
|
repo := &rpmOverrideRepoStub{err: errors.New("db down")}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 0}
|
||||||
|
group := &Group{ID: 10, RPMLimit: 10}
|
||||||
|
|
||||||
|
// override 查询失败后应继续尝试 group 分支(不直接拒绝)
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
|
||||||
|
require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
|
||||||
|
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
|
||||||
|
repo := &rpmOverrideRepoStub{override: nil}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 2}
|
||||||
|
group := &Group{ID: 10, RPMLimit: 0} // 分组未设限
|
||||||
|
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded)
|
||||||
|
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键")
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
|
||||||
|
cache := &userRPMCacheStub{}
|
||||||
|
repo := &rpmOverrideRepoStub{override: nil}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 0}
|
||||||
|
group := &Group{ID: 10, RPMLimit: 0}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
}
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
|
||||||
|
cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
|
||||||
|
repo := &rpmOverrideRepoStub{override: nil}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 0}
|
||||||
|
group := &Group{ID: 10, RPMLimit: 5}
|
||||||
|
|
||||||
|
// Redis 故障时应 fail-open,不拒绝请求
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||||
|
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
|
||||||
|
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
|
||||||
|
repo := &rpmOverrideRepoStub{}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
user := &User{ID: 1, RPMLimit: 2}
|
||||||
|
|
||||||
|
// 无 group(纯用户级限流场景),不应查询 rpm_override。
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
|
||||||
|
require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded)
|
||||||
|
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override")
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
|
||||||
|
cache := &userRPMCacheStub{}
|
||||||
|
repo := &rpmOverrideRepoStub{}
|
||||||
|
svc := newBillingServiceForRPM(t, cache, repo)
|
||||||
|
|
||||||
|
require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10}))
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
|
||||||
|
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls))
|
||||||
|
}
|
||||||
@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
|||||||
delay: 80 * time.Millisecond,
|
delay: 80 * time.Millisecond,
|
||||||
balance: 12.34,
|
balance: 12.34,
|
||||||
}
|
}
|
||||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
|
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
|
||||||
t.Cleanup(svc.Stop)
|
t.Cleanup(svc.Stop)
|
||||||
|
|
||||||
const goroutines = 16
|
const goroutines = 16
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
|
|||||||
|
|
||||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||||
cache := &billingCacheWorkerStub{}
|
cache := &billingCacheWorkerStub{}
|
||||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
|
||||||
t.Cleanup(svc.Stop)
|
t.Cleanup(svc.Stop)
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
|||||||
|
|
||||||
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
|
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
|
||||||
cache := &billingCacheWorkerStub{}
|
cache := &billingCacheWorkerStub{}
|
||||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
|
||||||
svc.Stop()
|
svc.Stop()
|
||||||
|
|
||||||
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
|
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
|
||||||
|
|||||||
@@ -217,6 +217,9 @@ func (s *BillingService) initFallbackPricing() {
|
|||||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||||
}
|
}
|
||||||
|
// GPT-5.5 暂无独立定价,回退到 GPT-5.4
|
||||||
|
s.fallbackPrices["gpt-5.5"] = s.fallbackPrices["gpt-5.4"]
|
||||||
|
|
||||||
s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{
|
s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{
|
||||||
InputPricePerToken: 7.5e-7,
|
InputPricePerToken: 7.5e-7,
|
||||||
OutputPricePerToken: 4.5e-6,
|
OutputPricePerToken: 4.5e-6,
|
||||||
@@ -288,6 +291,8 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
|||||||
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
|
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
|
||||||
normalized := normalizeCodexModel(modelLower)
|
normalized := normalizeCodexModel(modelLower)
|
||||||
switch normalized {
|
switch normalized {
|
||||||
|
case "gpt-5.5":
|
||||||
|
return s.fallbackPrices["gpt-5.5"]
|
||||||
case "gpt-5.4-mini":
|
case "gpt-5.4-mini":
|
||||||
return s.fallbackPrices["gpt-5.4-mini"]
|
return s.fallbackPrices["gpt-5.4-mini"]
|
||||||
case "gpt-5.4":
|
case "gpt-5.4":
|
||||||
@@ -637,7 +642,8 @@ func isOpenAIGPT54Model(model string) bool {
|
|||||||
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
|
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return normalizeCodexModel(trimmed) == "gpt-5.4"
|
normalized := normalizeCodexModel(trimmed)
|
||||||
|
return normalized == "gpt-5.4" || normalized == "gpt-5.5"
|
||||||
}
|
}
|
||||||
|
|
||||||
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
|
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
|
||||||
|
|||||||
@@ -170,9 +170,10 @@ const (
|
|||||||
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
|
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
|
||||||
|
|
||||||
// 默认配置
|
// 默认配置
|
||||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||||
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
|
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
|
||||||
|
SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制)
|
||||||
|
|
||||||
// 第三方认证来源默认授予配置
|
// 第三方认证来源默认授予配置
|
||||||
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
|
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
|
||||||
|
|||||||
@@ -59,6 +59,10 @@ type Group struct {
|
|||||||
DefaultMappedModel string
|
DefaultMappedModel string
|
||||||
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
|
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
|
||||||
|
|
||||||
|
// RPMLimit 分组级每分钟请求数上限(0 = 不限制)。
|
||||||
|
// 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。
|
||||||
|
RPMLimit int
|
||||||
|
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
|
||||||
|
|||||||
11
backend/internal/service/openai_403_counter.go
Normal file
11
backend/internal/service/openai_403_counter.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// OpenAI403CounterCache 追踪 OpenAI 账号连续 403 失败次数。
|
||||||
|
type OpenAI403CounterCache interface {
|
||||||
|
// IncrementOpenAI403Count 原子递增 403 计数并返回当前值。
|
||||||
|
IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error)
|
||||||
|
// ResetOpenAI403Count 成功后清零计数器。
|
||||||
|
ResetOpenAI403Count(ctx context.Context, accountID int64) error
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var codexModelMap = map[string]string{
|
var codexModelMap = map[string]string{
|
||||||
|
"gpt-5.5": "gpt-5.5",
|
||||||
"gpt-5.4": "gpt-5.4",
|
"gpt-5.4": "gpt-5.4",
|
||||||
"gpt-5.4-mini": "gpt-5.4-mini",
|
"gpt-5.4-mini": "gpt-5.4-mini",
|
||||||
"gpt-5.4-none": "gpt-5.4",
|
"gpt-5.4-none": "gpt-5.4",
|
||||||
@@ -207,6 +208,9 @@ func normalizeCodexModel(model string) string {
|
|||||||
|
|
||||||
normalized := strings.ToLower(modelID)
|
normalized := strings.ToLower(modelID)
|
||||||
|
|
||||||
|
if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") {
|
||||||
|
return "gpt-5.5"
|
||||||
|
}
|
||||||
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
|
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
|
||||||
return "gpt-5.4-mini"
|
return "gpt-5.4-mini"
|
||||||
}
|
}
|
||||||
|
|||||||
39
backend/internal/service/openai_gateway_403_reset_test.go
Normal file
39
backend/internal/service/openai_gateway_403_reset_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAI403CounterResetStub struct {
|
||||||
|
resetCalls []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAI403CounterResetStub) IncrementOpenAI403Count(context.Context, int64, int) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
|
||||||
|
s.resetCalls = append(s.resetCalls, accountID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
|
||||||
|
counter := &openAI403CounterResetStub{}
|
||||||
|
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
|
||||||
|
rateLimitSvc.SetOpenAI403CounterCache(counter)
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
rateLimitService: rateLimitSvc,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{},
|
||||||
|
Account: &Account{ID: 777, Platform: PlatformOpenAI},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []int64{777}, counter.resetCalls)
|
||||||
|
}
|
||||||
@@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.
|
|||||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) {
|
||||||
|
imagePrice := 0.02
|
||||||
|
groupID := int64(12)
|
||||||
|
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_image_per_request",
|
||||||
|
Model: "gpt-image-2",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 1110,
|
||||||
|
OutputTokens: 1756,
|
||||||
|
ImageOutputTokens: 1756,
|
||||||
|
},
|
||||||
|
ImageCount: 2,
|
||||||
|
ImageSize: "1K",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 1008,
|
||||||
|
GroupID: i64p(groupID),
|
||||||
|
Group: &Group{
|
||||||
|
ID: groupID,
|
||||||
|
RateMultiplier: 1.0,
|
||||||
|
ImagePrice1K: &imagePrice,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
User: &User{ID: 2008},
|
||||||
|
Account: &Account{ID: 3008},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||||
|
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||||
|
require.Equal(t, 2, usageRepo.lastLog.ImageCount)
|
||||||
|
require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12)
|
||||||
|
require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12)
|
||||||
|
require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12)
|
||||||
|
require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
|
||||||
|
require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4425,6 +4425,9 @@ type OpenAIRecordUsageInput struct {
|
|||||||
// RecordUsage records usage and deducts balance
|
// RecordUsage records usage and deducts balance
|
||||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||||
result := input.Result
|
result := input.Result
|
||||||
|
if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
|
||||||
|
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
||||||
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
||||||
@@ -4622,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
|
|||||||
serviceTier string,
|
serviceTier string,
|
||||||
) (*CostBreakdown, error) {
|
) (*CostBreakdown, error) {
|
||||||
if result != nil && result.ImageCount > 0 {
|
if result != nil && result.ImageCount > 0 {
|
||||||
if hasOpenAIImageUsageTokens(result) {
|
|
||||||
cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize)
|
|
||||||
if err == nil {
|
|
||||||
return cost, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
|
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
|
||||||
}
|
}
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
if s.resolver != nil && apiKey.Group != nil {
|
||||||
@@ -4646,32 +4643,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
|
|||||||
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) calculateOpenAIImageTokenCost(
|
|
||||||
ctx context.Context,
|
|
||||||
apiKey *APIKey,
|
|
||||||
billingModel string,
|
|
||||||
multiplier float64,
|
|
||||||
tokens UsageTokens,
|
|
||||||
serviceTier string,
|
|
||||||
sizeTier string,
|
|
||||||
) (*CostBreakdown, error) {
|
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
|
||||||
gid := apiKey.Group.ID
|
|
||||||
return s.billingService.CalculateCostUnified(CostInput{
|
|
||||||
Ctx: ctx,
|
|
||||||
Model: billingModel,
|
|
||||||
GroupID: &gid,
|
|
||||||
Tokens: tokens,
|
|
||||||
RequestCount: 1,
|
|
||||||
SizeTier: sizeTier,
|
|
||||||
RateMultiplier: multiplier,
|
|
||||||
ServiceTier: serviceTier,
|
|
||||||
Resolver: s.resolver,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) calculateOpenAIImageCost(
|
func (s *OpenAIGatewayService) calculateOpenAIImageCost(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
billingModel string,
|
billingModel string,
|
||||||
@@ -4679,7 +4650,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
|
|||||||
result *OpenAIForwardResult,
|
result *OpenAIForwardResult,
|
||||||
multiplier float64,
|
multiplier float64,
|
||||||
) *CostBreakdown {
|
) *CostBreakdown {
|
||||||
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil {
|
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil &&
|
||||||
|
(resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) {
|
||||||
gid := apiKey.Group.ID
|
gid := apiKey.Group.ID
|
||||||
cost, err := s.billingService.CalculateCostUnified(CostInput{
|
cost, err := s.billingService.CalculateCostUnified(CostInput{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
@@ -4720,17 +4692,6 @@ func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasOpenAIImageUsageTokens(result *OpenAIForwardResult) bool {
|
|
||||||
if result == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return result.Usage.InputTokens > 0 ||
|
|
||||||
result.Usage.OutputTokens > 0 ||
|
|
||||||
result.Usage.CacheCreationInputTokens > 0 ||
|
|
||||||
result.Usage.CacheReadInputTokens > 0 ||
|
|
||||||
result.Usage.ImageOutputTokens > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
||||||
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
||||||
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
853
backend/internal/service/openai_images_responses.go
Normal file
853
backend/internal/service/openai_images_responses.go
Normal file
@@ -0,0 +1,853 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAIResponsesImageResult struct {
|
||||||
|
Result string
|
||||||
|
RevisedPrompt string
|
||||||
|
OutputFormat string
|
||||||
|
Size string
|
||||||
|
Background string
|
||||||
|
Quality string
|
||||||
|
Model string
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string {
|
||||||
|
if strings.TrimSpace(result.Result) != "" {
|
||||||
|
return strings.TrimSpace(result.OutputFormat) + "|" + strings.TrimSpace(result.Result)
|
||||||
|
}
|
||||||
|
return "item:" + strings.TrimSpace(itemID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendOpenAIResponsesImageResultDedup(results *[]openAIResponsesImageResult, seen map[string]struct{}, itemID string, result openAIResponsesImageResult) bool {
|
||||||
|
if results == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
key := openAIResponsesImageResultKey(itemID, result)
|
||||||
|
if key != "" {
|
||||||
|
if _, exists := seen[key]; exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
}
|
||||||
|
*results = append(*results, result)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIResponsesImageResult) {
|
||||||
|
if dst == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.OutputFormat); trimmed != "" {
|
||||||
|
dst.OutputFormat = trimmed
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.Size); trimmed != "" {
|
||||||
|
dst.Size = trimmed
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.Background); trimmed != "" {
|
||||||
|
dst.Background = trimmed
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.Quality); trimmed != "" {
|
||||||
|
dst.Quality = trimmed
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.Model); trimmed != "" {
|
||||||
|
dst.Model = trimmed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) {
|
||||||
|
switch gjson.GetBytes(payload, "type").String() {
|
||||||
|
case "response.created", "response.in_progress", "response.completed":
|
||||||
|
default:
|
||||||
|
return openAIResponsesImageResult{}, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
response := gjson.GetBytes(payload, "response")
|
||||||
|
if !response.Exists() {
|
||||||
|
return openAIResponsesImageResult{}, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := openAIResponsesImageResult{
|
||||||
|
OutputFormat: strings.TrimSpace(response.Get("tools.0.output_format").String()),
|
||||||
|
Size: strings.TrimSpace(response.Get("tools.0.size").String()),
|
||||||
|
Background: strings.TrimSpace(response.Get("tools.0.background").String()),
|
||||||
|
Quality: strings.TrimSpace(response.Get("tools.0.quality").String()),
|
||||||
|
Model: strings.TrimSpace(response.Get("tools.0.model").String()),
|
||||||
|
}
|
||||||
|
return meta, response.Get("created_at").Int(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIImagesStreamPartialPayload(
|
||||||
|
eventType string,
|
||||||
|
b64 string,
|
||||||
|
partialImageIndex int64,
|
||||||
|
responseFormat string,
|
||||||
|
createdAt int64,
|
||||||
|
meta openAIResponsesImageResult,
|
||||||
|
) []byte {
|
||||||
|
if createdAt <= 0 {
|
||||||
|
createdAt = time.Now().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte(`{"type":"","created_at":0,"partial_image_index":0,"b64_json":""}`)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "type", eventType)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "created_at", createdAt)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "partial_image_index", partialImageIndex)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "b64_json", b64)
|
||||||
|
if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(meta.OutputFormat)+";base64,"+b64)
|
||||||
|
}
|
||||||
|
if meta.Background != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "background", meta.Background)
|
||||||
|
}
|
||||||
|
if meta.OutputFormat != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "output_format", meta.OutputFormat)
|
||||||
|
}
|
||||||
|
if meta.Quality != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "quality", meta.Quality)
|
||||||
|
}
|
||||||
|
if meta.Size != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "size", meta.Size)
|
||||||
|
}
|
||||||
|
if meta.Model != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "model", meta.Model)
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIImagesStreamCompletedPayload(
|
||||||
|
eventType string,
|
||||||
|
img openAIResponsesImageResult,
|
||||||
|
responseFormat string,
|
||||||
|
createdAt int64,
|
||||||
|
usageRaw []byte,
|
||||||
|
) []byte {
|
||||||
|
if createdAt <= 0 {
|
||||||
|
createdAt = time.Now().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte(`{"type":"","created_at":0,"b64_json":""}`)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "type", eventType)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "created_at", createdAt)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "b64_json", img.Result)
|
||||||
|
if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result)
|
||||||
|
}
|
||||||
|
if img.Background != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "background", img.Background)
|
||||||
|
}
|
||||||
|
if img.OutputFormat != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "output_format", img.OutputFormat)
|
||||||
|
}
|
||||||
|
if img.Quality != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "quality", img.Quality)
|
||||||
|
}
|
||||||
|
if img.Size != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "size", img.Size)
|
||||||
|
}
|
||||||
|
if img.Model != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "model", img.Model)
|
||||||
|
}
|
||||||
|
if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
|
||||||
|
payload, _ = sjson.SetRawBytes(payload, "usage", usageRaw)
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIImageOutputMIMEType(outputFormat string) string {
|
||||||
|
if outputFormat == "" {
|
||||||
|
return "image/png"
|
||||||
|
}
|
||||||
|
if strings.Contains(outputFormat, "/") {
|
||||||
|
return outputFormat
|
||||||
|
}
|
||||||
|
switch strings.ToLower(strings.TrimSpace(outputFormat)) {
|
||||||
|
case "png":
|
||||||
|
return "image/png"
|
||||||
|
case "jpg", "jpeg":
|
||||||
|
return "image/jpeg"
|
||||||
|
case "webp":
|
||||||
|
return "image/webp"
|
||||||
|
default:
|
||||||
|
return "image/png"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIImageUploadToDataURL(upload OpenAIImagesUpload) (string, error) {
|
||||||
|
if len(upload.Data) == 0 {
|
||||||
|
return "", fmt.Errorf("upload %q is empty", strings.TrimSpace(upload.FileName))
|
||||||
|
}
|
||||||
|
contentType := strings.TrimSpace(upload.ContentType)
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = http.DetectContentType(upload.Data)
|
||||||
|
}
|
||||||
|
return "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(upload.Data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel string) ([]byte, error) {
|
||||||
|
if parsed == nil {
|
||||||
|
return nil, fmt.Errorf("parsed images request is required")
|
||||||
|
}
|
||||||
|
prompt := strings.TrimSpace(parsed.Prompt)
|
||||||
|
if prompt == "" {
|
||||||
|
return nil, fmt.Errorf("prompt is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
inputImages := make([]string, 0, len(parsed.InputImageURLs)+len(parsed.Uploads))
|
||||||
|
for _, imageURL := range parsed.InputImageURLs {
|
||||||
|
if trimmed := strings.TrimSpace(imageURL); trimmed != "" {
|
||||||
|
inputImages = append(inputImages, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, upload := range parsed.Uploads {
|
||||||
|
dataURL, err := openAIImageUploadToDataURL(upload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inputImages = append(inputImages, dataURL)
|
||||||
|
}
|
||||||
|
if parsed.IsEdits() && len(inputImages) == 0 {
|
||||||
|
return nil, fmt.Errorf("image input is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`)
|
||||||
|
req, _ = sjson.SetBytes(req, "model", openAIImagesResponsesMainModel)
|
||||||
|
|
||||||
|
input := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`)
|
||||||
|
input, _ = sjson.SetBytes(input, "0.content.0.text", prompt)
|
||||||
|
for index, imageURL := range inputImages {
|
||||||
|
part := []byte(`{"type":"input_image","image_url":""}`)
|
||||||
|
part, _ = sjson.SetBytes(part, "image_url", imageURL)
|
||||||
|
input, _ = sjson.SetRawBytes(input, fmt.Sprintf("0.content.%d", index+1), part)
|
||||||
|
}
|
||||||
|
req, _ = sjson.SetRawBytes(req, "input", input)
|
||||||
|
|
||||||
|
action := "generate"
|
||||||
|
if parsed.IsEdits() {
|
||||||
|
action = "edit"
|
||||||
|
}
|
||||||
|
tool := []byte(`{"type":"image_generation","action":"","model":""}`)
|
||||||
|
tool, _ = sjson.SetBytes(tool, "action", action)
|
||||||
|
tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel))
|
||||||
|
|
||||||
|
for _, field := range []struct {
|
||||||
|
path string
|
||||||
|
value string
|
||||||
|
}{
|
||||||
|
{path: "size", value: parsed.Size},
|
||||||
|
{path: "quality", value: parsed.Quality},
|
||||||
|
{path: "background", value: parsed.Background},
|
||||||
|
{path: "output_format", value: parsed.OutputFormat},
|
||||||
|
{path: "moderation", value: parsed.Moderation},
|
||||||
|
{path: "style", value: parsed.Style},
|
||||||
|
} {
|
||||||
|
if trimmed := strings.TrimSpace(field.value); trimmed != "" {
|
||||||
|
tool, _ = sjson.SetBytes(tool, field.path, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if parsed.OutputCompression != nil {
|
||||||
|
tool, _ = sjson.SetBytes(tool, "output_compression", *parsed.OutputCompression)
|
||||||
|
}
|
||||||
|
if parsed.PartialImages != nil {
|
||||||
|
tool, _ = sjson.SetBytes(tool, "partial_images", *parsed.PartialImages)
|
||||||
|
}
|
||||||
|
|
||||||
|
maskImageURL := strings.TrimSpace(parsed.MaskImageURL)
|
||||||
|
if parsed.MaskUpload != nil {
|
||||||
|
dataURL, err := openAIImageUploadToDataURL(*parsed.MaskUpload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
maskImageURL = dataURL
|
||||||
|
}
|
||||||
|
if maskImageURL != "" {
|
||||||
|
tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", maskImageURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ = sjson.SetRawBytes(req, "tools", []byte(`[]`))
|
||||||
|
req, _ = sjson.SetRawBytes(req, "tools.-1", tool)
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) {
|
||||||
|
if gjson.GetBytes(payload, "type").String() != "response.completed" {
|
||||||
|
return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type")
|
||||||
|
}
|
||||||
|
|
||||||
|
createdAt := gjson.GetBytes(payload, "response.created_at").Int()
|
||||||
|
if createdAt <= 0 {
|
||||||
|
createdAt = time.Now().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
results []openAIResponsesImageResult
|
||||||
|
firstMeta openAIResponsesImageResult
|
||||||
|
)
|
||||||
|
output := gjson.GetBytes(payload, "response.output")
|
||||||
|
if output.IsArray() {
|
||||||
|
for _, item := range output.Array() {
|
||||||
|
if item.Get("type").String() != "image_generation_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result := strings.TrimSpace(item.Get("result").String())
|
||||||
|
if result == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry := openAIResponsesImageResult{
|
||||||
|
Result: result,
|
||||||
|
RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()),
|
||||||
|
OutputFormat: strings.TrimSpace(item.Get("output_format").String()),
|
||||||
|
Size: strings.TrimSpace(item.Get("size").String()),
|
||||||
|
Background: strings.TrimSpace(item.Get("background").String()),
|
||||||
|
Quality: strings.TrimSpace(item.Get("quality").String()),
|
||||||
|
}
|
||||||
|
if len(results) == 0 {
|
||||||
|
firstMeta = entry
|
||||||
|
}
|
||||||
|
results = append(results, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var usageRaw []byte
|
||||||
|
if usage := gjson.GetBytes(payload, "response.tool_usage.image_gen"); usage.Exists() && usage.IsObject() {
|
||||||
|
usageRaw = []byte(usage.Raw)
|
||||||
|
}
|
||||||
|
return results, createdAt, usageRaw, firstMeta, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAIImageFromResponsesOutputItemDone(payload []byte) (openAIResponsesImageResult, string, bool, error) {
|
||||||
|
if gjson.GetBytes(payload, "type").String() != "response.output_item.done" {
|
||||||
|
return openAIResponsesImageResult{}, "", false, fmt.Errorf("unexpected event type")
|
||||||
|
}
|
||||||
|
|
||||||
|
item := gjson.GetBytes(payload, "item")
|
||||||
|
if !item.Exists() || item.Get("type").String() != "image_generation_call" {
|
||||||
|
return openAIResponsesImageResult{}, "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := strings.TrimSpace(item.Get("result").String())
|
||||||
|
if result == "" {
|
||||||
|
return openAIResponsesImageResult{}, "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := openAIResponsesImageResult{
|
||||||
|
Result: result,
|
||||||
|
RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()),
|
||||||
|
OutputFormat: strings.TrimSpace(item.Get("output_format").String()),
|
||||||
|
Size: strings.TrimSpace(item.Get("size").String()),
|
||||||
|
Background: strings.TrimSpace(item.Get("background").String()),
|
||||||
|
Quality: strings.TrimSpace(item.Get("quality").String()),
|
||||||
|
}
|
||||||
|
return entry, strings.TrimSpace(item.Get("id").String()), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, bool, error) {
|
||||||
|
var (
|
||||||
|
fallbackResults []openAIResponsesImageResult
|
||||||
|
fallbackSeen = make(map[string]struct{})
|
||||||
|
createdAt int64
|
||||||
|
usageRaw []byte
|
||||||
|
foundFinal bool
|
||||||
|
responseMeta openAIResponsesImageResult
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, line := range bytes.Split(body, []byte("\n")) {
|
||||||
|
line = bytes.TrimRight(line, "\r")
|
||||||
|
data, ok := extractOpenAISSEDataLine(string(line))
|
||||||
|
if !ok || data == "" || data == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := []byte(data)
|
||||||
|
if !gjson.ValidBytes(payload) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok {
|
||||||
|
mergeOpenAIResponsesImageMeta(&responseMeta, meta)
|
||||||
|
if eventCreatedAt > 0 {
|
||||||
|
createdAt = eventCreatedAt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch gjson.GetBytes(payload, "type").String() {
|
||||||
|
case "response.output_item.done":
|
||||||
|
result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, nil, openAIResponsesImageResult{}, false, err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
mergeOpenAIResponsesImageMeta(&result, responseMeta)
|
||||||
|
appendOpenAIResponsesImageResultDedup(&fallbackResults, fallbackSeen, itemID, result)
|
||||||
|
}
|
||||||
|
case "response.completed":
|
||||||
|
results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, nil, openAIResponsesImageResult{}, false, err
|
||||||
|
}
|
||||||
|
foundFinal = true
|
||||||
|
if completedAt > 0 {
|
||||||
|
createdAt = completedAt
|
||||||
|
}
|
||||||
|
if len(completedUsageRaw) > 0 {
|
||||||
|
usageRaw = completedUsageRaw
|
||||||
|
}
|
||||||
|
if len(results) > 0 {
|
||||||
|
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||||
|
return results, createdAt, usageRaw, firstMeta, true, nil
|
||||||
|
}
|
||||||
|
if len(fallbackResults) > 0 {
|
||||||
|
firstMeta = fallbackResults[0]
|
||||||
|
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||||
|
return fallbackResults, createdAt, usageRaw, firstMeta, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fallbackResults) > 0 {
|
||||||
|
firstMeta := fallbackResults[0]
|
||||||
|
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||||
|
return fallbackResults, createdAt, usageRaw, firstMeta, foundFinal, nil
|
||||||
|
}
|
||||||
|
return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIImagesAPIResponse(
|
||||||
|
results []openAIResponsesImageResult,
|
||||||
|
createdAt int64,
|
||||||
|
usageRaw []byte,
|
||||||
|
firstMeta openAIResponsesImageResult,
|
||||||
|
responseFormat string,
|
||||||
|
) ([]byte, error) {
|
||||||
|
if createdAt <= 0 {
|
||||||
|
createdAt = time.Now().Unix()
|
||||||
|
}
|
||||||
|
out := []byte(`{"created":0,"data":[]}`)
|
||||||
|
out, _ = sjson.SetBytes(out, "created", createdAt)
|
||||||
|
|
||||||
|
format := strings.ToLower(strings.TrimSpace(responseFormat))
|
||||||
|
if format == "" {
|
||||||
|
format = "b64_json"
|
||||||
|
}
|
||||||
|
for _, img := range results {
|
||||||
|
item := []byte(`{}`)
|
||||||
|
if format == "url" {
|
||||||
|
item, _ = sjson.SetBytes(item, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result)
|
||||||
|
} else {
|
||||||
|
item, _ = sjson.SetBytes(item, "b64_json", img.Result)
|
||||||
|
}
|
||||||
|
if img.RevisedPrompt != "" {
|
||||||
|
item, _ = sjson.SetBytes(item, "revised_prompt", img.RevisedPrompt)
|
||||||
|
}
|
||||||
|
out, _ = sjson.SetRawBytes(out, "data.-1", item)
|
||||||
|
}
|
||||||
|
if firstMeta.Background != "" {
|
||||||
|
out, _ = sjson.SetBytes(out, "background", firstMeta.Background)
|
||||||
|
}
|
||||||
|
if firstMeta.OutputFormat != "" {
|
||||||
|
out, _ = sjson.SetBytes(out, "output_format", firstMeta.OutputFormat)
|
||||||
|
}
|
||||||
|
if firstMeta.Quality != "" {
|
||||||
|
out, _ = sjson.SetBytes(out, "quality", firstMeta.Quality)
|
||||||
|
}
|
||||||
|
if firstMeta.Size != "" {
|
||||||
|
out, _ = sjson.SetBytes(out, "size", firstMeta.Size)
|
||||||
|
}
|
||||||
|
if firstMeta.Model != "" {
|
||||||
|
out, _ = sjson.SetBytes(out, "model", firstMeta.Model)
|
||||||
|
}
|
||||||
|
if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
|
||||||
|
out, _ = sjson.SetRawBytes(out, "usage", usageRaw)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIImagesStreamPrefix(parsed *OpenAIImagesRequest) string {
|
||||||
|
if parsed != nil && parsed.IsEdits() {
|
||||||
|
return "image_edit"
|
||||||
|
}
|
||||||
|
return "image_generation"
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIImagesStreamErrorBody(message string) []byte {
|
||||||
|
body := []byte(`{"type":"error","error":{"type":"upstream_error","message":""}}`)
|
||||||
|
if strings.TrimSpace(message) == "" {
|
||||||
|
message = "upstream request failed"
|
||||||
|
}
|
||||||
|
body, _ = sjson.SetBytes(body, "error.message", message)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flusher http.Flusher, eventName string, payload []byte) error {
|
||||||
|
if strings.TrimSpace(eventName) != "" {
|
||||||
|
if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
responseFormat string,
|
||||||
|
fallbackModel string,
|
||||||
|
) (OpenAIUsage, int, error) {
|
||||||
|
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||||
|
if err != nil {
|
||||||
|
return OpenAIUsage{}, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var usage OpenAIUsage
|
||||||
|
for _, line := range bytes.Split(body, []byte("\n")) {
|
||||||
|
line = bytes.TrimRight(line, "\r")
|
||||||
|
data, ok := extractOpenAISSEDataLine(string(line))
|
||||||
|
if !ok || data == "" || data == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dataBytes := []byte(data)
|
||||||
|
s.parseSSEUsageBytes(dataBytes, &usage)
|
||||||
|
}
|
||||||
|
results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body)
|
||||||
|
if err != nil {
|
||||||
|
return OpenAIUsage{}, 0, err
|
||||||
|
}
|
||||||
|
if len(results) == 0 {
|
||||||
|
return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(firstMeta.Model) == "" {
|
||||||
|
firstMeta.Model = strings.TrimSpace(fallbackModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat)
|
||||||
|
if err != nil {
|
||||||
|
return OpenAIUsage{}, 0, err
|
||||||
|
}
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
|
c.Data(resp.StatusCode, "application/json; charset=utf-8", responseBody)
|
||||||
|
return usage, len(results), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
startTime time.Time,
|
||||||
|
responseFormat string,
|
||||||
|
streamPrefix string,
|
||||||
|
fallbackModel string,
|
||||||
|
) (OpenAIUsage, int, *int, error) {
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Status(resp.StatusCode)
|
||||||
|
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
|
||||||
|
}
|
||||||
|
|
||||||
|
format := strings.ToLower(strings.TrimSpace(responseFormat))
|
||||||
|
if format == "" {
|
||||||
|
format = "b64_json"
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
usage := OpenAIUsage{}
|
||||||
|
imageCount := 0
|
||||||
|
var firstTokenMs *int
|
||||||
|
emitted := make(map[string]struct{})
|
||||||
|
pendingResults := make([]openAIResponsesImageResult, 0, 1)
|
||||||
|
pendingSeen := make(map[string]struct{})
|
||||||
|
streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)}
|
||||||
|
var createdAt int64
|
||||||
|
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if len(line) > 0 {
|
||||||
|
trimmedLine := strings.TrimRight(string(line), "\r\n")
|
||||||
|
data, ok := extractOpenAISSEDataLine(trimmedLine)
|
||||||
|
if ok && data != "" && data != "[DONE]" {
|
||||||
|
if firstTokenMs == nil {
|
||||||
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
|
firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
dataBytes := []byte(data)
|
||||||
|
s.parseSSEUsageBytes(dataBytes, &usage)
|
||||||
|
if gjson.ValidBytes(dataBytes) {
|
||||||
|
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
|
||||||
|
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
|
||||||
|
if eventCreatedAt > 0 {
|
||||||
|
createdAt = eventCreatedAt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch gjson.GetBytes(dataBytes, "type").String() {
|
||||||
|
case "response.image_generation_call.partial_image":
|
||||||
|
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
|
||||||
|
if b64 != "" {
|
||||||
|
eventName := streamPrefix + ".partial_image"
|
||||||
|
partialMeta := streamMeta
|
||||||
|
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
|
||||||
|
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
|
||||||
|
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
|
||||||
|
})
|
||||||
|
payload := buildOpenAIImagesStreamPartialPayload(
|
||||||
|
eventName,
|
||||||
|
b64,
|
||||||
|
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
|
||||||
|
format,
|
||||||
|
createdAt,
|
||||||
|
partialMeta,
|
||||||
|
)
|
||||||
|
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||||
|
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "response.output_item.done":
|
||||||
|
img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
|
||||||
|
if extractErr != nil {
|
||||||
|
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||||
|
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
mergeOpenAIResponsesImageMeta(&streamMeta, img)
|
||||||
|
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||||
|
key := openAIResponsesImageResultKey(itemID, img)
|
||||||
|
if _, exists := emitted[key]; exists {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if _, exists := pendingSeen[key]; exists {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pendingSeen[key] = struct{}{}
|
||||||
|
pendingResults = append(pendingResults, img)
|
||||||
|
case "response.completed":
|
||||||
|
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
|
||||||
|
if extractErr != nil {
|
||||||
|
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||||
|
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
|
||||||
|
}
|
||||||
|
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
|
||||||
|
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
|
||||||
|
finalSeen := make(map[string]struct{})
|
||||||
|
for _, img := range results {
|
||||||
|
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||||
|
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||||
|
}
|
||||||
|
for _, img := range pendingResults {
|
||||||
|
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||||
|
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||||
|
}
|
||||||
|
if len(finalResults) == 0 {
|
||||||
|
err = fmt.Errorf("upstream did not return image output")
|
||||||
|
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
||||||
|
return OpenAIUsage{}, imageCount, firstTokenMs, err
|
||||||
|
}
|
||||||
|
eventName := streamPrefix + ".completed"
|
||||||
|
for _, img := range finalResults {
|
||||||
|
key := openAIResponsesImageResultKey("", img)
|
||||||
|
if _, exists := emitted[key]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
|
||||||
|
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||||
|
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||||
|
}
|
||||||
|
emitted[key] = struct{}{}
|
||||||
|
}
|
||||||
|
imageCount = len(emitted)
|
||||||
|
return usage, imageCount, firstTokenMs, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
||||||
|
return OpenAIUsage{}, imageCount, firstTokenMs, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageCount > 0 {
|
||||||
|
return usage, imageCount, firstTokenMs, nil
|
||||||
|
}
|
||||||
|
if len(pendingResults) > 0 {
|
||||||
|
eventName := streamPrefix + ".completed"
|
||||||
|
for _, img := range pendingResults {
|
||||||
|
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||||
|
key := openAIResponsesImageResultKey("", img)
|
||||||
|
if _, exists := emitted[key]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
|
||||||
|
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||||
|
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||||
|
}
|
||||||
|
emitted[key] = struct{}{}
|
||||||
|
}
|
||||||
|
imageCount = len(emitted)
|
||||||
|
return usage, imageCount, firstTokenMs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
streamErr := fmt.Errorf("stream disconnected before image generation completed")
|
||||||
|
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
|
||||||
|
return OpenAIUsage{}, imageCount, firstTokenMs, streamErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||||
|
ctx context.Context,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
parsed *OpenAIImagesRequest,
|
||||||
|
channelMappedModel string,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
requestModel := strings.TrimSpace(parsed.Model)
|
||||||
|
if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
|
||||||
|
requestModel = mapped
|
||||||
|
}
|
||||||
|
if requestModel == "" {
|
||||||
|
requestModel = "gpt-image-2"
|
||||||
|
}
|
||||||
|
if err := validateOpenAIImagesModel(requestModel); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.openai_gateway",
|
||||||
|
"[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d",
|
||||||
|
requestModel,
|
||||||
|
parsed.Endpoint,
|
||||||
|
account.Type,
|
||||||
|
len(parsed.Uploads),
|
||||||
|
)
|
||||||
|
if parsed.N > 1 {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.openai_gateway",
|
||||||
|
"[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
|
||||||
|
parsed.N,
|
||||||
|
requestModel,
|
||||||
|
parsed.Endpoint,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, _, err := s.GetAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, requestModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
setOpsUpstreamRequestBody(c, responsesBody)
|
||||||
|
|
||||||
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||||
|
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||||
|
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
upstreamStart := time.Now()
|
||||||
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
|
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||||
|
if err != nil {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
})
|
||||||
|
s.handleFailoverSideEffects(ctx, resp, account)
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: respBody,
|
||||||
|
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.handleErrorResponse(ctx, resp, c, account, responsesBody)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
var (
|
||||||
|
usage OpenAIUsage
|
||||||
|
imageCount int
|
||||||
|
firstTokenMs *int
|
||||||
|
)
|
||||||
|
if parsed.Stream {
|
||||||
|
usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if imageCount <= 0 {
|
||||||
|
imageCount = parsed.N
|
||||||
|
}
|
||||||
|
return &OpenAIForwardResult{
|
||||||
|
RequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Usage: usage,
|
||||||
|
Model: requestModel,
|
||||||
|
UpstreamModel: requestModel,
|
||||||
|
Stream: parsed.Stream,
|
||||||
|
ResponseHeaders: resp.Header.Clone(),
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
ImageCount: imageCount,
|
||||||
|
ImageSize: parsed.SizeTier,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -3,13 +3,17 @@ package service
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/textproto"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
|
||||||
@@ -70,6 +74,58 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
|
|||||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
require.NoError(t, writer.WriteField("model", "gpt-image-2"))
|
||||||
|
require.NoError(t, writer.WriteField("prompt", "replace foreground"))
|
||||||
|
require.NoError(t, writer.WriteField("output_format", "png"))
|
||||||
|
require.NoError(t, writer.WriteField("input_fidelity", "high"))
|
||||||
|
require.NoError(t, writer.WriteField("output_compression", "80"))
|
||||||
|
require.NoError(t, writer.WriteField("partial_images", "2"))
|
||||||
|
|
||||||
|
imageHeader := make(textproto.MIMEHeader)
|
||||||
|
imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`)
|
||||||
|
imageHeader.Set("Content-Type", "image/png")
|
||||||
|
imagePart, err := writer.CreatePart(imageHeader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = imagePart.Write([]byte("source-image-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
maskHeader := make(textproto.MIMEHeader)
|
||||||
|
maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`)
|
||||||
|
maskHeader.Set("Content-Type", "image/png")
|
||||||
|
maskPart, err := writer.CreatePart(maskHeader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = maskPart.Write([]byte("mask-image-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Len(t, parsed.Uploads, 1)
|
||||||
|
require.NotNil(t, parsed.MaskUpload)
|
||||||
|
require.True(t, parsed.HasMask)
|
||||||
|
require.Equal(t, "png", parsed.OutputFormat)
|
||||||
|
require.Equal(t, "high", parsed.InputFidelity)
|
||||||
|
require.NotNil(t, parsed.OutputCompression)
|
||||||
|
require.Equal(t, 80, *parsed.OutputCompression)
|
||||||
|
require.NotNil(t, parsed.PartialImages)
|
||||||
|
require.Equal(t, 2, *parsed.PartialImages)
|
||||||
|
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) {
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
body := []byte(`{"prompt":"draw a cat"}`)
|
body := []byte(`{"prompt":"draw a cat"}`)
|
||||||
@@ -121,6 +177,40 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *te
|
|||||||
require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`)
|
require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSONEditURLs(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{
|
||||||
|
"model":"gpt-image-2",
|
||||||
|
"prompt":"replace the background",
|
||||||
|
"images":[{"image_url":"https://example.com/source.png"}],
|
||||||
|
"mask":{"image_url":"https://example.com/mask.png"},
|
||||||
|
"input_fidelity":"high",
|
||||||
|
"output_compression":90,
|
||||||
|
"partial_images":2,
|
||||||
|
"response_format":"url"
|
||||||
|
}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, []string{"https://example.com/source.png"}, parsed.InputImageURLs)
|
||||||
|
require.Equal(t, "https://example.com/mask.png", parsed.MaskImageURL)
|
||||||
|
require.Equal(t, "high", parsed.InputFidelity)
|
||||||
|
require.NotNil(t, parsed.OutputCompression)
|
||||||
|
require.Equal(t, 90, *parsed.OutputCompression)
|
||||||
|
require.NotNil(t, parsed.PartialImages)
|
||||||
|
require.Equal(t, 2, *parsed.PartialImages)
|
||||||
|
require.True(t, parsed.HasMask)
|
||||||
|
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) {
|
func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) {
|
||||||
items := collectOpenAIImagePointers([]byte(`{
|
items := collectOpenAIImagePointers([]byte(`{
|
||||||
"revised_prompt": "cat astronaut",
|
"revised_prompt": "cat astronaut",
|
||||||
@@ -157,3 +247,472 @@ func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []byte("ABC"), data)
|
require.Equal(t, []byte("ABC"), data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityBasic))
|
||||||
|
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIImageTestSSEEvent struct {
|
||||||
|
Name string
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOpenAIImageTestSSEEvents(body string) []openAIImageTestSSEEvent {
|
||||||
|
chunks := strings.Split(body, "\n\n")
|
||||||
|
events := make([]openAIImageTestSSEEvent, 0, len(chunks))
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
chunk = strings.TrimSpace(chunk)
|
||||||
|
if chunk == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var event openAIImageTestSSEEvent
|
||||||
|
for _, line := range strings.Split(chunk, "\n") {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(line, "event: "):
|
||||||
|
event.Name = strings.TrimSpace(strings.TrimPrefix(line, "event: "))
|
||||||
|
case strings.HasPrefix(line, "data: "):
|
||||||
|
event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if event.Name != "" || event.Data != "" {
|
||||||
|
events = append(events, event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return events
|
||||||
|
}
|
||||||
|
|
||||||
|
func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string) (openAIImageTestSSEEvent, bool) {
|
||||||
|
for _, event := range events {
|
||||||
|
if event.Name == name {
|
||||||
|
return event, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return openAIImageTestSSEEvent{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
c.Set("api_key", &APIKey{ID: 42})
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
"X-Request-Id": []string{"req_img_123"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc.httpUpstream = upstream
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token-123",
|
||||||
|
"chatgpt_account_id": "acct-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "gpt-image-2", result.Model)
|
||||||
|
require.Equal(t, "gpt-image-2", result.UpstreamModel)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
require.Equal(t, 11, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 22, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 7, result.Usage.ImageOutputTokens)
|
||||||
|
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.Equal(t, chatgptCodexURL, upstream.lastReq.URL.String())
|
||||||
|
require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
|
||||||
|
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type"))
|
||||||
|
require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept"))
|
||||||
|
require.Equal(t, "acct-123", upstream.lastReq.Header.Get("chatgpt-account-id"))
|
||||||
|
require.Equal(t, "responses=experimental", upstream.lastReq.Header.Get("OpenAI-Beta"))
|
||||||
|
|
||||||
|
require.Equal(t, openAIImagesResponsesMainModel, gjson.GetBytes(upstream.lastBody, "model").String())
|
||||||
|
require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||||
|
require.Equal(t, "image_generation", gjson.GetBytes(upstream.lastBody, "tools.0.type").String())
|
||||||
|
require.Equal(t, "generate", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
|
||||||
|
require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String())
|
||||||
|
require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String())
|
||||||
|
require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists())
|
||||||
|
require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
|
||||||
|
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
|
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
"X-Request-Id": []string{"req_img_stream"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000001,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||||
|
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\",\"background\":\"auto\"}\n\n" +
|
||||||
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc.httpUpstream = upstream
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 2,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.Stream)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
|
||||||
|
partial, ok := findOpenAIImageTestSSEEvent(events, "image_generation.partial_image")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "image_generation.partial_image", gjson.Get(partial.Data, "type").String())
|
||||||
|
require.Equal(t, int64(1710000001), gjson.Get(partial.Data, "created_at").Int())
|
||||||
|
require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/png;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String())
|
||||||
|
require.Equal(t, "png", gjson.Get(partial.Data, "output_format").String())
|
||||||
|
require.Equal(t, "high", gjson.Get(partial.Data, "quality").String())
|
||||||
|
require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String())
|
||||||
|
require.Equal(t, "auto", gjson.Get(partial.Data, "background").String())
|
||||||
|
|
||||||
|
completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String())
|
||||||
|
require.Equal(t, int64(1710000001), gjson.Get(completed.Data, "created_at").Int())
|
||||||
|
require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
|
||||||
|
require.Equal(t, "png", gjson.Get(completed.Data, "output_format").String())
|
||||||
|
require.Equal(t, "high", gjson.Get(completed.Data, "quality").String())
|
||||||
|
require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String())
|
||||||
|
require.Equal(t, "auto", gjson.Get(completed.Data, "background").String())
|
||||||
|
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||||
|
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
require.NoError(t, writer.WriteField("model", "gpt-image-2"))
|
||||||
|
require.NoError(t, writer.WriteField("prompt", "replace background with aurora"))
|
||||||
|
require.NoError(t, writer.WriteField("input_fidelity", "high"))
|
||||||
|
require.NoError(t, writer.WriteField("output_format", "webp"))
|
||||||
|
require.NoError(t, writer.WriteField("quality", "high"))
|
||||||
|
|
||||||
|
imageHeader := make(textproto.MIMEHeader)
|
||||||
|
imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`)
|
||||||
|
imageHeader.Set("Content-Type", "image/png")
|
||||||
|
imagePart, err := writer.CreatePart(imageHeader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = imagePart.Write([]byte("png-image-content"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
maskHeader := make(textproto.MIMEHeader)
|
||||||
|
maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`)
|
||||||
|
maskHeader.Set("Content-Type", "image/png")
|
||||||
|
maskPart, err := writer.CreatePart(maskHeader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = maskPart.Write([]byte("png-mask-content"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
c.Set("api_key", &APIKey{ID: 100})
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
"X-Request-Id": []string{"req_img_edit_123"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000002,\"usage\":{\"input_tokens\":13,\"output_tokens\":21,\"output_tokens_details\":{\"image_tokens\":8}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\",\"quality\":\"high\"}]}}\n\n" +
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc.httpUpstream = upstream
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 3,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
|
||||||
|
require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
|
||||||
|
require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").Exists())
|
||||||
|
require.Equal(t, "webp", gjson.GetBytes(upstream.lastBody, "tools.0.output_format").String())
|
||||||
|
require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String(), "data:image/png;base64,"))
|
||||||
|
require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String(), "data:image/png;base64,"))
|
||||||
|
require.Equal(t, "replace background with aurora", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
|
||||||
|
require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
|
require.Equal(t, "replace background with aurora", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{
|
||||||
|
"model":"gpt-image-2",
|
||||||
|
"prompt":"replace background with aurora",
|
||||||
|
"images":[{"image_url":"https://example.com/source.png"}],
|
||||||
|
"mask":{"image_url":"https://example.com/mask.png"},
|
||||||
|
"stream":true,
|
||||||
|
"response_format":"url"
|
||||||
|
}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000003,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||||
|
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\",\"background\":\"transparent\"}\n\n" +
|
||||||
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" +
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc.httpUpstream = upstream
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 4,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
|
||||||
|
require.Equal(t, "https://example.com/source.png", gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String())
|
||||||
|
require.Equal(t, "https://example.com/mask.png", gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String())
|
||||||
|
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
|
||||||
|
partial, ok := findOpenAIImageTestSSEEvent(events, "image_edit.partial_image")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "image_edit.partial_image", gjson.Get(partial.Data, "type").String())
|
||||||
|
require.Equal(t, int64(1710000003), gjson.Get(partial.Data, "created_at").Int())
|
||||||
|
require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/webp;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String())
|
||||||
|
require.Equal(t, "webp", gjson.Get(partial.Data, "output_format").String())
|
||||||
|
require.Equal(t, "high", gjson.Get(partial.Data, "quality").String())
|
||||||
|
require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String())
|
||||||
|
require.Equal(t, "transparent", gjson.Get(partial.Data, "background").String())
|
||||||
|
|
||||||
|
completed, ok := findOpenAIImageTestSSEEvent(events, "image_edit.completed")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "image_edit.completed", gjson.Get(completed.Data, "type").String())
|
||||||
|
require.Equal(t, int64(1710000003), gjson.Get(completed.Data, "created_at").Int())
|
||||||
|
require.Equal(t, "ZWRpdGVk", gjson.Get(completed.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/webp;base64,ZWRpdGVk", gjson.Get(completed.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
|
||||||
|
require.Equal(t, "webp", gjson.Get(completed.Data, "output_format").String())
|
||||||
|
require.Equal(t, "high", gjson.Get(completed.Data, "quality").String())
|
||||||
|
require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String())
|
||||||
|
require.Equal(t, "transparent", gjson.Get(completed.Data, "background").String())
|
||||||
|
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||||
|
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) {
|
||||||
|
parsed := &OpenAIImagesRequest{
|
||||||
|
Endpoint: openAIImagesGenerationsEndpoint,
|
||||||
|
Model: "gpt-image-2",
|
||||||
|
Prompt: "draw a cat",
|
||||||
|
N: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, body)
|
||||||
|
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
|
||||||
|
require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
|
||||||
|
parsed := &OpenAIImagesRequest{
|
||||||
|
Endpoint: openAIImagesEditsEndpoint,
|
||||||
|
Model: "gpt-image-2",
|
||||||
|
Prompt: "replace background",
|
||||||
|
InputFidelity: "high",
|
||||||
|
InputImageURLs: []string{
|
||||||
|
"https://example.com/source.png",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, body)
|
||||||
|
require.False(t, gjson.GetBytes(body, "tools.0.input_fidelity").Exists())
|
||||||
|
require.Equal(t, "edit", gjson.GetBytes(body, "tools.0.action").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testing.T) {
|
||||||
|
body := []byte(
|
||||||
|
"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000004}}\n\n" +
|
||||||
|
"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\"}}\n\n" +
|
||||||
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000004,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, foundFinal)
|
||||||
|
require.Equal(t, int64(1710000004), createdAt)
|
||||||
|
require.Len(t, results, 1)
|
||||||
|
require.Equal(t, "aGVsbG8=", results[0].Result)
|
||||||
|
require.Equal(t, "draw a cat", results[0].RevisedPrompt)
|
||||||
|
require.Equal(t, "png", firstMeta.OutputFormat)
|
||||||
|
require.JSONEq(t, `{"images":1}`, string(usageRaw))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
"X-Request-Id": []string{"req_img_stream_output_item_done"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" +
|
||||||
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000005,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc.httpUpstream = upstream
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 5,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.Stream)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
|
||||||
|
completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String())
|
||||||
|
require.Equal(t, int64(1710000005), gjson.Get(completed.Data, "created_at").Int())
|
||||||
|
require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
|
||||||
|
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||||
|
require.NotContains(t, rec.Body.String(), "event: error")
|
||||||
|
}
|
||||||
|
|||||||
@@ -794,6 +794,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GPT-5.5 回退到 GPT-5.4 定价
|
||||||
|
if strings.HasPrefix(model, "gpt-5.5") {
|
||||||
|
logger.With(zap.String("component", "service.pricing")).
|
||||||
|
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)"))
|
||||||
|
return openAIGPT54FallbackPricing
|
||||||
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(model, "gpt-5.4-mini") {
|
if strings.HasPrefix(model, "gpt-5.4-mini") {
|
||||||
logger.With(zap.String("component", "service.pricing")).
|
logger.With(zap.String("component", "service.pricing")).
|
||||||
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)"))
|
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)"))
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -23,6 +25,7 @@ type RateLimitService struct {
|
|||||||
geminiQuotaService *GeminiQuotaService
|
geminiQuotaService *GeminiQuotaService
|
||||||
tempUnschedCache TempUnschedCache
|
tempUnschedCache TempUnschedCache
|
||||||
timeoutCounterCache TimeoutCounterCache
|
timeoutCounterCache TimeoutCounterCache
|
||||||
|
openAI403CounterCache OpenAI403CounterCache
|
||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
tokenCacheInvalidator TokenCacheInvalidator
|
tokenCacheInvalidator TokenCacheInvalidator
|
||||||
usageCacheMu sync.RWMutex
|
usageCacheMu sync.RWMutex
|
||||||
@@ -52,6 +55,12 @@ type geminiUsageTotalsBatchProvider interface {
|
|||||||
|
|
||||||
const geminiPrecheckCacheTTL = time.Minute
|
const geminiPrecheckCacheTTL = time.Minute
|
||||||
|
|
||||||
|
const (
|
||||||
|
openAI403CooldownMinutesDefault = 10
|
||||||
|
openAI403DisableThreshold = 3
|
||||||
|
openAI403CounterWindowMinutes = 180
|
||||||
|
)
|
||||||
|
|
||||||
// NewRateLimitService 创建RateLimitService实例
|
// NewRateLimitService 创建RateLimitService实例
|
||||||
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
|
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
|
||||||
return &RateLimitService{
|
return &RateLimitService{
|
||||||
@@ -69,6 +78,11 @@ func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) {
|
|||||||
s.timeoutCounterCache = cache
|
s.timeoutCounterCache = cache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetOpenAI403CounterCache 设置 OpenAI 403 连续失败计数器(可选依赖)
|
||||||
|
func (s *RateLimitService) SetOpenAI403CounterCache(cache OpenAI403CounterCache) {
|
||||||
|
s.openAI403CounterCache = cache
|
||||||
|
}
|
||||||
|
|
||||||
// SetSettingService 设置系统设置服务(可选依赖)
|
// SetSettingService 设置系统设置服务(可选依赖)
|
||||||
func (s *RateLimitService) SetSettingService(settingService *SettingService) {
|
func (s *RateLimitService) SetSettingService(settingService *SettingService) {
|
||||||
s.settingService = settingService
|
s.settingService = settingService
|
||||||
@@ -655,6 +669,30 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
|
|||||||
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
|
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildForbiddenErrorMessage(prefix string, upstreamMsg string, responseBody []byte, fallback string) string {
|
||||||
|
prefix = strings.TrimSpace(prefix)
|
||||||
|
if prefix != "" && !strings.HasSuffix(prefix, " ") {
|
||||||
|
prefix += " "
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg := strings.TrimSpace(upstreamMsg); msg != "" {
|
||||||
|
return prefix + msg
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody := bytes.TrimSpace(responseBody)
|
||||||
|
if len(rawBody) > 0 {
|
||||||
|
if json.Valid(rawBody) {
|
||||||
|
var compact bytes.Buffer
|
||||||
|
if err := json.Compact(&compact, rawBody); err == nil {
|
||||||
|
return prefix + truncateForLog(compact.Bytes(), 512)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prefix + truncateForLog(rawBody, 512)
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + fallback
|
||||||
|
}
|
||||||
|
|
||||||
// handle403 处理 403 Forbidden 错误
|
// handle403 处理 403 Forbidden 错误
|
||||||
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
|
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
|
||||||
// 其他平台保持原有 SetError 行为。
|
// 其他平台保持原有 SetError 行为。
|
||||||
@@ -662,15 +700,64 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst
|
|||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
|
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
|
||||||
}
|
}
|
||||||
// 非 Antigravity 平台:保持原有行为
|
if account.Platform == PlatformOpenAI {
|
||||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
return s.handleOpenAI403(ctx, account, upstreamMsg, responseBody)
|
||||||
if upstreamMsg != "" {
|
|
||||||
msg = "Access forbidden (403): " + upstreamMsg
|
|
||||||
}
|
}
|
||||||
|
// 非 Antigravity 平台:保持原有行为
|
||||||
|
msg := buildForbiddenErrorMessage(
|
||||||
|
"Access forbidden (403):",
|
||||||
|
upstreamMsg,
|
||||||
|
responseBody,
|
||||||
|
"account may be suspended or lack permissions",
|
||||||
|
)
|
||||||
s.handleAuthError(ctx, account, msg)
|
s.handleAuthError(ctx, account, msg)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
|
||||||
|
msg := buildForbiddenErrorMessage(
|
||||||
|
"Access forbidden (403):",
|
||||||
|
upstreamMsg,
|
||||||
|
responseBody,
|
||||||
|
"account may be suspended or lack permissions",
|
||||||
|
)
|
||||||
|
|
||||||
|
if s.openAI403CounterCache == nil {
|
||||||
|
s.handleAuthError(ctx, account, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := s.openAI403CounterCache.IncrementOpenAI403Count(ctx, account.ID, openAI403CounterWindowMinutes)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("openai_403_increment_failed", "account_id", account.ID, "error", err)
|
||||||
|
s.handleAuthError(ctx, account, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if count >= openAI403DisableThreshold {
|
||||||
|
msg = fmt.Sprintf("%s | consecutive_403=%d/%d", msg, count, openAI403DisableThreshold)
|
||||||
|
s.handleAuthError(ctx, account, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute)
|
||||||
|
reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg)
|
||||||
|
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||||
|
slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
||||||
|
s.handleAuthError(ctx, account, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Warn(
|
||||||
|
"openai_403_temp_unschedulable",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"until", until,
|
||||||
|
"count", count,
|
||||||
|
"threshold", openAI403DisableThreshold,
|
||||||
|
)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// handleAntigravity403 处理 Antigravity 平台的 403 错误
|
// handleAntigravity403 处理 Antigravity 平台的 403 错误
|
||||||
// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复)
|
// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复)
|
||||||
// violation(违规封号)→ 永久 SetError(需人工处理)
|
// violation(违规封号)→ 永久 SetError(需人工处理)
|
||||||
@@ -681,10 +768,12 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
|
|||||||
switch fbType {
|
switch fbType {
|
||||||
case forbiddenTypeValidation:
|
case forbiddenTypeValidation:
|
||||||
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
|
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
|
||||||
msg := "Validation required (403): account needs Google verification"
|
msg := buildForbiddenErrorMessage(
|
||||||
if upstreamMsg != "" {
|
"Validation required (403):",
|
||||||
msg = "Validation required (403): " + upstreamMsg
|
upstreamMsg,
|
||||||
}
|
responseBody,
|
||||||
|
"account needs Google verification",
|
||||||
|
)
|
||||||
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
|
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
|
||||||
msg += " | validation_url: " + validationURL
|
msg += " | validation_url: " + validationURL
|
||||||
}
|
}
|
||||||
@@ -693,19 +782,23 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
|
|||||||
|
|
||||||
case forbiddenTypeViolation:
|
case forbiddenTypeViolation:
|
||||||
// 违规封号: 永久禁用,需人工处理
|
// 违规封号: 永久禁用,需人工处理
|
||||||
msg := "Account violation (403): terms of service violation"
|
msg := buildForbiddenErrorMessage(
|
||||||
if upstreamMsg != "" {
|
"Account violation (403):",
|
||||||
msg = "Account violation (403): " + upstreamMsg
|
upstreamMsg,
|
||||||
}
|
responseBody,
|
||||||
|
"terms of service violation",
|
||||||
|
)
|
||||||
s.handleAuthError(ctx, account, msg)
|
s.handleAuthError(ctx, account, msg)
|
||||||
return true
|
return true
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// 通用 403: 保持原有行为
|
// 通用 403: 保持原有行为
|
||||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
msg := buildForbiddenErrorMessage(
|
||||||
if upstreamMsg != "" {
|
"Access forbidden (403):",
|
||||||
msg = "Access forbidden (403): " + upstreamMsg
|
upstreamMsg,
|
||||||
}
|
responseBody,
|
||||||
|
"account may be suspended or lack permissions",
|
||||||
|
)
|
||||||
s.handleAuthError(ctx, account, msg)
|
s.handleAuthError(ctx, account, msg)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -1221,9 +1314,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
|
|||||||
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
|
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
s.ResetOpenAI403Counter(ctx, accountID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *RateLimitService) ResetOpenAI403Counter(ctx context.Context, accountID int64) {
|
||||||
|
if s == nil || s.openAI403CounterCache == nil || accountID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.openAI403CounterCache.ResetOpenAI403Count(ctx, accountID); err != nil {
|
||||||
|
slog.Warn("openai_403_reset_failed", "account_id", accountID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RecoverAccountState 按需恢复账号的可恢复运行时状态。
|
// RecoverAccountState 按需恢复账号的可恢复运行时状态。
|
||||||
func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) {
|
func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
@@ -1250,6 +1353,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
|
|||||||
}
|
}
|
||||||
result.ClearedRateLimit = true
|
result.ClearedRateLimit = true
|
||||||
}
|
}
|
||||||
|
if result.ClearedError || result.ClearedRateLimit {
|
||||||
|
s.ResetOpenAI403Counter(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct {
|
|||||||
updateCredentialsCalls int
|
updateCredentialsCalls int
|
||||||
lastCredentials map[string]any
|
lastCredentials map[string]any
|
||||||
lastErrorMsg string
|
lastErrorMsg string
|
||||||
|
lastTempReason string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
@@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error
|
|||||||
|
|
||||||
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
r.tempCalls++
|
r.tempCalls++
|
||||||
|
r.lastTempReason = reason
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type openAI403CounterCacheStub struct {
|
||||||
|
counts []int64
|
||||||
|
resetCalls []int64
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return 0, s.err
|
||||||
|
}
|
||||||
|
if len(s.counts) == 0 {
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
count := s.counts[0]
|
||||||
|
s.counts = s.counts[1:]
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
|
||||||
|
s.resetCalls = append(s.resetCalls, accountID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
|
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
|
||||||
r.accounts = append(r.accounts, account)
|
r.accounts = append(r.accounts, account)
|
||||||
return r.err
|
return r.err
|
||||||
|
|||||||
64
backend/internal/service/ratelimit_service_403_test.go
Normal file
64
backend/internal/service/ratelimit_service_403_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
counter := &openAI403CounterCacheStub{counts: []int64{1}}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
service.SetOpenAI403CounterCache(counter)
|
||||||
|
account := &Account{
|
||||||
|
ID: 301,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDisable := service.HandleUpstreamError(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
http.StatusForbidden,
|
||||||
|
http.Header{},
|
||||||
|
[]byte(`{"error":{"message":"temporary edge rejection"}}`),
|
||||||
|
)
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 0, repo.setErrorCalls)
|
||||||
|
require.Equal(t, 1, repo.tempCalls)
|
||||||
|
require.Contains(t, repo.lastTempReason, "temporary edge rejection")
|
||||||
|
require.Contains(t, repo.lastTempReason, "(1/3)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
counter := &openAI403CounterCacheStub{counts: []int64{3}}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
service.SetOpenAI403CounterCache(counter)
|
||||||
|
account := &Account{
|
||||||
|
ID: 302,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDisable := service.HandleUpstreamError(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
http.StatusForbidden,
|
||||||
|
http.Header{},
|
||||||
|
[]byte(`{"error":{"message":"workspace forbidden by policy"}}`),
|
||||||
|
)
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
|
require.Equal(t, 0, repo.tempCalls)
|
||||||
|
require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy")
|
||||||
|
require.Contains(t, repo.lastErrorMsg, "consecutive_403=3/3")
|
||||||
|
}
|
||||||
@@ -7,6 +7,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
|
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
|
||||||
@@ -259,6 +262,53 @@ func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRateLimitService_HandleUpstreamError_403PreservesOriginalUpstreamMessage(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 201,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDisable := service.HandleUpstreamError(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
403,
|
||||||
|
http.Header{},
|
||||||
|
[]byte(`{"error":{"message":"workspace forbidden by policy","type":"invalid_request_error"}}`),
|
||||||
|
)
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
|
require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy")
|
||||||
|
require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitService_HandleUpstreamError_403FallsBackToRawBody(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 202,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDisable := service.HandleUpstreamError(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
403,
|
||||||
|
http.Header{},
|
||||||
|
[]byte(`{"error":{"type":"access_denied","details":{"reason":"ip_blocked"}}}`),
|
||||||
|
)
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
|
require.Contains(t, repo.lastErrorMsg, `"access_denied"`)
|
||||||
|
require.Contains(t, repo.lastErrorMsg, `"ip_blocked"`)
|
||||||
|
require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions")
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
|
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
|
||||||
// Test when only secondary has data, no window_minutes
|
// Test when only secondary has data, no window_minutes
|
||||||
sUsed := 60.0
|
sUsed := 60.0
|
||||||
|
|||||||
@@ -1167,6 +1167,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
|||||||
// 默认配置
|
// 默认配置
|
||||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||||
|
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
|
||||||
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal default subscriptions: %w", err)
|
return nil, fmt.Errorf("marshal default subscriptions: %w", err)
|
||||||
@@ -1538,6 +1539,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
|
|||||||
return s.cfg.Default.UserBalance
|
return s.cfg.Default.UserBalance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。
|
||||||
|
func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int {
|
||||||
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit)
|
||||||
|
if err != nil || value == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if v, err := strconv.Atoi(value); err == nil && v >= 0 {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
|
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
|
||||||
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
|
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
|
||||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
|
||||||
@@ -1706,6 +1719,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
SettingKeyOIDCConnectUserInfoUsernamePath: "",
|
SettingKeyOIDCConnectUserInfoUsernamePath: "",
|
||||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||||
|
SettingKeyDefaultUserRPMLimit: "0",
|
||||||
SettingKeyDefaultSubscriptions: "[]",
|
SettingKeyDefaultSubscriptions: "[]",
|
||||||
SettingKeyAuthSourceDefaultEmailBalance: "0",
|
SettingKeyAuthSourceDefaultEmailBalance: "0",
|
||||||
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
|
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
|
||||||
@@ -1822,6 +1836,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
|
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 {
|
||||||
|
result.DefaultUserRPMLimit = rpm
|
||||||
|
}
|
||||||
|
|
||||||
// 解析浮点数类型
|
// 解析浮点数类型
|
||||||
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
|
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
|
||||||
result.DefaultBalance = balance
|
result.DefaultBalance = balance
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ type SystemSettings struct {
|
|||||||
|
|
||||||
DefaultConcurrency int
|
DefaultConcurrency int
|
||||||
DefaultBalance float64
|
DefaultBalance float64
|
||||||
|
DefaultUserRPMLimit int
|
||||||
DefaultSubscriptions []DefaultSubscriptionSetting
|
DefaultSubscriptions []DefaultSubscriptionSetting
|
||||||
|
|
||||||
// Model fallback configuration
|
// Model fallback configuration
|
||||||
|
|||||||
@@ -49,6 +49,15 @@ type User struct {
|
|||||||
BalanceNotifyExtraEmails []NotifyEmailEntry
|
BalanceNotifyExtraEmails []NotifyEmailEntry
|
||||||
TotalRecharged float64
|
TotalRecharged float64
|
||||||
|
|
||||||
|
// RPMLimit 用户级每分钟请求数上限(0 = 不限制)。仅在所用分组未设置 rpm_limit
|
||||||
|
// 且该 (用户, 分组) 无 rpm_override 时作为全局兜底生效,计数键 rpm:u:{userID}:{min}。
|
||||||
|
RPMLimit int
|
||||||
|
|
||||||
|
// UserGroupRPMOverride 来自 auth cache snapshot 的 (user, group) RPM 覆盖值。
|
||||||
|
// nil = 该 API Key 对应的 (user, group) 无 override;非 nil 时 checkRPM 直接使用,
|
||||||
|
// 避免每请求查 DB。字段不持久化到数据库。
|
||||||
|
UserGroupRPMOverride *int
|
||||||
|
|
||||||
APIKeys []APIKey
|
APIKeys []APIKey
|
||||||
Subscriptions []UserSubscription
|
Subscriptions []UserSubscription
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,14 +2,16 @@ package service
|
|||||||
|
|
||||||
import "context"
|
import "context"
|
||||||
|
|
||||||
// UserGroupRateEntry 分组下用户专属倍率条目
|
// UserGroupRateEntry 分组下用户专属倍率/RPM 条目。
|
||||||
|
// RateMultiplier 与 RPMOverride 均为指针以支持"未设置"语义(NULL)。
|
||||||
type UserGroupRateEntry struct {
|
type UserGroupRateEntry struct {
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
UserName string `json:"user_name"`
|
UserName string `json:"user_name"`
|
||||||
UserEmail string `json:"user_email"`
|
UserEmail string `json:"user_email"`
|
||||||
UserNotes string `json:"user_notes"`
|
UserNotes string `json:"user_notes"`
|
||||||
UserStatus string `json:"user_status"`
|
UserStatus string `json:"user_status"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier *float64 `json:"rate_multiplier,omitempty"`
|
||||||
|
RPMOverride *int `json:"rpm_override,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GroupRateMultiplierInput 批量设置分组倍率的输入条目
|
// GroupRateMultiplierInput 批量设置分组倍率的输入条目
|
||||||
@@ -18,30 +20,44 @@ type GroupRateMultiplierInput struct {
|
|||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserGroupRateRepository 用户专属分组倍率仓储接口
|
// GroupRPMOverrideInput 批量设置分组 RPM override 的输入条目。
|
||||||
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
// RPMOverride 为 *int 以支持清除(nil)语义。
|
||||||
|
type GroupRPMOverrideInput struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
RPMOverride *int `json:"rpm_override"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserGroupRateRepository 用户专属分组倍率/RPM 仓储接口。
|
||||||
|
// 允许管理员为特定用户设置分组的专属计费倍率与 RPM 上限,覆盖分组默认值。
|
||||||
type UserGroupRateRepository interface {
|
type UserGroupRateRepository interface {
|
||||||
// GetByUserID 获取用户的所有专属分组倍率
|
// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||||
// 返回 map[groupID]rateMultiplier
|
|
||||||
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
|
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
|
||||||
|
|
||||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
|
||||||
// 如果未设置专属倍率,返回 nil
|
|
||||||
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
|
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
|
||||||
|
|
||||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
|
||||||
|
GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error)
|
||||||
|
|
||||||
|
// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
|
||||||
GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||||
|
|
||||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
// SyncUserGroupRates 同步用户的分组专属倍率;nil 表示清空该分组的 rate_multiplier
|
||||||
// rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
|
|
||||||
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
|
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
|
||||||
|
|
||||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组数据)
|
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组 rate 部分)
|
||||||
SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||||
|
|
||||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
|
// SyncGroupRPMOverrides 批量同步分组的用户专属 RPM(替换整组 rpm_override 部分)。
|
||||||
|
// 条目中 RPMOverride 为 nil 时清空对应行的 rpm_override;非 nil 时 upsert。
|
||||||
|
SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
|
||||||
|
|
||||||
|
// ClearGroupRPMOverrides 清空指定分组的所有 rpm_override(整组 rpm 部分归 NULL)
|
||||||
|
ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
|
||||||
|
|
||||||
|
// DeleteByGroupID 删除指定分组的所有用户专属条目(分组删除时调用)
|
||||||
DeleteByGroupID(ctx context.Context, groupID int64) error
|
DeleteByGroupID(ctx context.Context, groupID int64) error
|
||||||
|
|
||||||
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
|
// DeleteByUserID 删除指定用户的所有专属条目(用户删除时调用)
|
||||||
DeleteByUserID(ctx context.Context, userID int64) error
|
DeleteByUserID(ctx context.Context, userID int64) error
|
||||||
}
|
}
|
||||||
|
|||||||
25
backend/internal/service/user_rpm_cache.go
Normal file
25
backend/internal/service/user_rpm_cache.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// UserRPMCache 用户/分组级 RPM 计数器接口。
|
||||||
|
//
|
||||||
|
// 与账号级 RPMCache 的区别:
|
||||||
|
// - RPMCache —— 按外部 AI provider 账号聚合(key: rpm:{accountID}:{min})。
|
||||||
|
// - UserRPMCache —— 按用户或 (用户, 分组) 聚合,杜绝"同一用户创建多个 API Key 绕过 RPM"的路径。
|
||||||
|
// key 形如 rpm:ug:{userID}:{groupID}:{min} 或 rpm:u:{userID}:{min}。
|
||||||
|
type UserRPMCache interface {
|
||||||
|
// IncrementUserGroupRPM 原子递增 (user, group) 级分钟计数并返回最新值。
|
||||||
|
// 用于分组 rpm_limit 与 user-group rpm_override 两种命中分支。
|
||||||
|
IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)
|
||||||
|
|
||||||
|
// IncrementUserRPM 原子递增用户级分钟计数并返回最新值。
|
||||||
|
// 用于用户全局 rpm_limit 兜底分支(分组未设且无 override 时)。
|
||||||
|
IncrementUserRPM(ctx context.Context, userID int64) (count int, err error)
|
||||||
|
|
||||||
|
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读,不递增)。
|
||||||
|
GetUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)
|
||||||
|
|
||||||
|
// GetUserRPM 获取用户当前分钟已用 RPM(只读,不递增)。
|
||||||
|
GetUserRPM(ctx context.Context, userID int64) (count int, err error)
|
||||||
|
}
|
||||||
@@ -39,6 +39,11 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
|||||||
return NewEmailQueueService(emailService, 3)
|
return NewEmailQueueService(emailService, 3)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideOAuthRefreshAPI creates OAuthRefreshAPI with the default lock TTL.
|
||||||
|
func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
|
||||||
|
return NewOAuthRefreshAPI(accountRepo, tokenCache)
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideTokenRefreshService creates and starts TokenRefreshService
|
// ProvideTokenRefreshService creates and starts TokenRefreshService
|
||||||
func ProvideTokenRefreshService(
|
func ProvideTokenRefreshService(
|
||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
@@ -210,11 +215,13 @@ func ProvideRateLimitService(
|
|||||||
geminiQuotaService *GeminiQuotaService,
|
geminiQuotaService *GeminiQuotaService,
|
||||||
tempUnschedCache TempUnschedCache,
|
tempUnschedCache TempUnschedCache,
|
||||||
timeoutCounterCache TimeoutCounterCache,
|
timeoutCounterCache TimeoutCounterCache,
|
||||||
|
openAI403CounterCache OpenAI403CounterCache,
|
||||||
settingService *SettingService,
|
settingService *SettingService,
|
||||||
tokenCacheInvalidator TokenCacheInvalidator,
|
tokenCacheInvalidator TokenCacheInvalidator,
|
||||||
) *RateLimitService {
|
) *RateLimitService {
|
||||||
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
|
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
|
||||||
svc.SetTimeoutCounterCache(timeoutCounterCache)
|
svc.SetTimeoutCounterCache(timeoutCounterCache)
|
||||||
|
svc.SetOpenAI403CounterCache(openAI403CounterCache)
|
||||||
svc.SetSettingService(settingService)
|
svc.SetSettingService(settingService)
|
||||||
svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
|
svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
|
||||||
return svc
|
return svc
|
||||||
@@ -384,6 +391,19 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies.
|
||||||
|
func ProvideBillingCacheService(
|
||||||
|
cache BillingCache,
|
||||||
|
userRepo UserRepository,
|
||||||
|
subRepo UserSubscriptionRepository,
|
||||||
|
apiKeyRepo APIKeyRepository,
|
||||||
|
rpmCache UserRPMCache,
|
||||||
|
rateRepo UserGroupRateRepository,
|
||||||
|
cfg *config.Config,
|
||||||
|
) *BillingCacheService {
|
||||||
|
return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
// ProviderSet is the Wire provider set for all services
|
// ProviderSet is the Wire provider set for all services
|
||||||
var ProviderSet = wire.NewSet(
|
var ProviderSet = wire.NewSet(
|
||||||
// Core services
|
// Core services
|
||||||
@@ -400,7 +420,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewDashboardService,
|
NewDashboardService,
|
||||||
ProvidePricingService,
|
ProvidePricingService,
|
||||||
NewBillingService,
|
NewBillingService,
|
||||||
NewBillingCacheService,
|
ProvideBillingCacheService,
|
||||||
NewAnnouncementService,
|
NewAnnouncementService,
|
||||||
NewAdminService,
|
NewAdminService,
|
||||||
NewGatewayService,
|
NewGatewayService,
|
||||||
@@ -412,7 +432,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewCompositeTokenCacheInvalidator,
|
NewCompositeTokenCacheInvalidator,
|
||||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||||
NewAntigravityOAuthService,
|
NewAntigravityOAuthService,
|
||||||
NewOAuthRefreshAPI,
|
ProvideOAuthRefreshAPI,
|
||||||
ProvideGeminiTokenProvider,
|
ProvideGeminiTokenProvider,
|
||||||
NewGeminiMessagesCompatService,
|
NewGeminiMessagesCompatService,
|
||||||
ProvideAntigravityTokenProvider,
|
ProvideAntigravityTokenProvider,
|
||||||
|
|||||||
7
backend/migrations/125_add_group_rpm_limit.sql
Normal file
7
backend/migrations/125_add_group_rpm_limit.sql
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
-- Add per-group Requests-Per-Minute limit.
|
||||||
|
-- rpm_limit: 分组统一 RPM 上限(0 = 不限制)。
|
||||||
|
-- 一旦配置即接管该用户在该分组的限流,覆盖用户级 users.rpm_limit。
|
||||||
|
-- 计数键:rpm:ug:{user_id}:{group_id}:{minute}。
|
||||||
|
ALTER TABLE groups ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0;
|
||||||
|
|
||||||
|
COMMENT ON COLUMN groups.rpm_limit IS '分组 RPM 上限;0 表示不限制;设置后接管该分组用户的限流(覆盖用户级 rpm_limit)。';
|
||||||
7
backend/migrations/126_add_user_rpm_limit.sql
Normal file
7
backend/migrations/126_add_user_rpm_limit.sql
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
-- Add per-user Requests-Per-Minute cap.
|
||||||
|
-- rpm_limit: 用户全局 RPM 兜底(0 = 不限制)。
|
||||||
|
-- 仅当所访问分组未设置 rpm_limit 且无 user-group rpm_override 时作为兜底生效。
|
||||||
|
-- 计数键:rpm:u:{user_id}:{minute}。
|
||||||
|
ALTER TABLE users ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0;
|
||||||
|
|
||||||
|
COMMENT ON COLUMN users.rpm_limit IS '用户级 RPM 兜底上限;0 表示不限制;仅当分组未设置 rpm_limit 时生效。';
|
||||||
16
backend/migrations/127_add_user_group_rpm_override.sql
Normal file
16
backend/migrations/127_add_user_group_rpm_override.sql
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
-- 在已有的"用户专属分组倍率表"上扩展 rpm_override 列;同时放宽 rate_multiplier 为可空,
|
||||||
|
-- 使一行记录可以只覆盖 rate、只覆盖 rpm,或同时覆盖两者。
|
||||||
|
-- 语义:
|
||||||
|
-- - rate_multiplier NULL → 该用户在此分组使用 groups.rate_multiplier 默认值
|
||||||
|
-- - rate_multiplier 非 NULL → 覆盖分组默认计费倍率
|
||||||
|
-- - rpm_override NULL → 该用户在此分组使用 groups.rpm_limit 默认值
|
||||||
|
-- - rpm_override 非 NULL → 覆盖分组默认 RPM(0 = 不限制)
|
||||||
|
-- 用户级 users.rpm_limit 仍独立生效(跨分组总配额)。
|
||||||
|
ALTER TABLE user_group_rate_multipliers
|
||||||
|
ADD COLUMN IF NOT EXISTS rpm_override integer NULL;
|
||||||
|
|
||||||
|
ALTER TABLE user_group_rate_multipliers
|
||||||
|
ALTER COLUMN rate_multiplier DROP NOT NULL;
|
||||||
|
|
||||||
|
COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率;NULL 表示沿用分组默认倍率。';
|
||||||
|
COMMENT ON COLUMN user_group_rate_multipliers.rpm_override IS '专属 RPM 上限;NULL 表示沿用分组默认;0 表示该用户在此分组不受 RPM 限制。';
|
||||||
@@ -164,7 +164,8 @@ export interface GroupRateMultiplierEntry {
|
|||||||
user_email: string
|
user_email: string
|
||||||
user_notes: string
|
user_notes: string
|
||||||
user_status: string
|
user_status: string
|
||||||
rate_multiplier: number
|
rate_multiplier?: number | null
|
||||||
|
rpm_override?: number | null
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -205,9 +206,7 @@ export async function clearGroupRateMultipliers(id: number): Promise<{ message:
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Batch set rate multipliers for users in a group
|
* Batch set rate multipliers for users in a group
|
||||||
* @param id - Group ID
|
* Only touches rate_multiplier column; preserves rpm_override on existing rows.
|
||||||
* @param entries - Array of { user_id, rate_multiplier }
|
|
||||||
* @returns Success confirmation
|
|
||||||
*/
|
*/
|
||||||
export async function batchSetGroupRateMultipliers(
|
export async function batchSetGroupRateMultipliers(
|
||||||
id: number,
|
id: number,
|
||||||
@@ -220,6 +219,60 @@ export async function batchSetGroupRateMultipliers(
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RPM override entry for a user in a group
|
||||||
|
*/
|
||||||
|
export interface GroupRPMOverrideEntry {
|
||||||
|
user_id: number
|
||||||
|
user_name: string
|
||||||
|
user_email: string
|
||||||
|
user_notes: string
|
||||||
|
user_status: string
|
||||||
|
rpm_override: number
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get RPM overrides for users in a group (subset of rate-multipliers endpoint).
|
||||||
|
*/
|
||||||
|
export async function getGroupRPMOverrides(id: number): Promise<GroupRPMOverrideEntry[]> {
|
||||||
|
const { data } = await apiClient.get<GroupRateMultiplierEntry[]>(
|
||||||
|
`/admin/groups/${id}/rate-multipliers`
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
.filter(e => e.rpm_override != null)
|
||||||
|
.map(e => ({
|
||||||
|
user_id: e.user_id,
|
||||||
|
user_name: e.user_name,
|
||||||
|
user_email: e.user_email,
|
||||||
|
user_notes: e.user_notes,
|
||||||
|
user_status: e.user_status,
|
||||||
|
rpm_override: e.rpm_override as number
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Batch set RPM overrides for users in a group.
|
||||||
|
* Only touches rpm_override column; preserves rate_multiplier on existing rows.
|
||||||
|
*/
|
||||||
|
export async function batchSetGroupRPMOverrides(
|
||||||
|
id: number,
|
||||||
|
entries: Array<{ user_id: number; rpm_override: number }>
|
||||||
|
): Promise<{ message: string }> {
|
||||||
|
const { data } = await apiClient.put<{ message: string }>(
|
||||||
|
`/admin/groups/${id}/rpm-overrides`,
|
||||||
|
{ entries }
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear all RPM overrides for a group (preserves rate_multiplier).
|
||||||
|
*/
|
||||||
|
export async function clearGroupRPMOverrides(id: number): Promise<{ message: string }> {
|
||||||
|
const { data } = await apiClient.delete<{ message: string }>(`/admin/groups/${id}/rpm-overrides`)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get usage summary (today + cumulative cost) for all groups
|
* Get usage summary (today + cumulative cost) for all groups
|
||||||
* @param timezone - IANA timezone string (e.g. "Asia/Shanghai")
|
* @param timezone - IANA timezone string (e.g. "Asia/Shanghai")
|
||||||
@@ -262,6 +315,9 @@ export const groupsAPI = {
|
|||||||
getGroupRateMultipliers,
|
getGroupRateMultipliers,
|
||||||
clearGroupRateMultipliers,
|
clearGroupRateMultipliers,
|
||||||
batchSetGroupRateMultipliers,
|
batchSetGroupRateMultipliers,
|
||||||
|
getGroupRPMOverrides,
|
||||||
|
clearGroupRPMOverrides,
|
||||||
|
batchSetGroupRPMOverrides,
|
||||||
updateSortOrder,
|
updateSortOrder,
|
||||||
getUsageSummary,
|
getUsageSummary,
|
||||||
getCapacitySummary
|
getCapacitySummary
|
||||||
|
|||||||
@@ -309,6 +309,7 @@ export interface SystemSettings {
|
|||||||
// Default settings
|
// Default settings
|
||||||
default_balance: number;
|
default_balance: number;
|
||||||
default_concurrency: number;
|
default_concurrency: number;
|
||||||
|
default_user_rpm_limit: number;
|
||||||
default_subscriptions: DefaultSubscriptionSetting[];
|
default_subscriptions: DefaultSubscriptionSetting[];
|
||||||
auth_source_default_email_balance?: number;
|
auth_source_default_email_balance?: number;
|
||||||
auth_source_default_email_concurrency?: number;
|
auth_source_default_email_concurrency?: number;
|
||||||
@@ -489,6 +490,7 @@ export interface UpdateSettingsRequest {
|
|||||||
totp_enabled?: boolean; // TOTP 双因素认证
|
totp_enabled?: boolean; // TOTP 双因素认证
|
||||||
default_balance?: number;
|
default_balance?: number;
|
||||||
default_concurrency?: number;
|
default_concurrency?: number;
|
||||||
|
default_user_rpm_limit?: number;
|
||||||
default_subscriptions?: DefaultSubscriptionSetting[];
|
default_subscriptions?: DefaultSubscriptionSetting[];
|
||||||
auth_source_default_email_balance?: number;
|
auth_source_default_email_balance?: number;
|
||||||
auth_source_default_email_concurrency?: number;
|
auth_source_default_email_concurrency?: number;
|
||||||
|
|||||||
434
frontend/src/components/admin/group/GroupRPMOverridesModal.vue
Normal file
434
frontend/src/components/admin/group/GroupRPMOverridesModal.vue
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
<template>
|
||||||
|
<BaseDialog :show="show" :title="t('admin.groups.rpmOverridesTitle')" width="wide" @close="handleClose">
|
||||||
|
<div v-if="group" class="space-y-4">
|
||||||
|
<!-- 分组信息 -->
|
||||||
|
<div class="flex flex-wrap items-center gap-3 rounded-lg bg-gray-50 px-4 py-2.5 text-sm dark:bg-dark-700">
|
||||||
|
<span class="inline-flex items-center gap-1.5" :class="platformColorClass">
|
||||||
|
<PlatformIcon :platform="group.platform" size="sm" />
|
||||||
|
{{ t('admin.groups.platforms.' + group.platform) }}
|
||||||
|
</span>
|
||||||
|
<span class="text-gray-400">|</span>
|
||||||
|
<span class="font-medium text-gray-900 dark:text-white">{{ group.name }}</span>
|
||||||
|
<span class="text-gray-400">|</span>
|
||||||
|
<span class="text-gray-600 dark:text-gray-400">
|
||||||
|
{{ t('admin.groups.groupRpmDefault') }}: {{ group.rpm_limit || 0 }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 操作区:添加用户 -->
|
||||||
|
<div class="rounded-lg border border-gray-200 p-3 dark:border-dark-600">
|
||||||
|
<h4 class="mb-2 text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.groups.addUserRpm') }}
|
||||||
|
</h4>
|
||||||
|
<div class="flex items-end gap-2">
|
||||||
|
<div class="relative flex-1">
|
||||||
|
<input
|
||||||
|
v-model="searchQuery"
|
||||||
|
type="text"
|
||||||
|
autocomplete="off"
|
||||||
|
class="input w-full"
|
||||||
|
:placeholder="t('admin.groups.searchUserPlaceholder')"
|
||||||
|
@input="handleSearchUsers"
|
||||||
|
@focus="showDropdown = true"
|
||||||
|
/>
|
||||||
|
<div
|
||||||
|
v-if="showDropdown && searchResults.length > 0"
|
||||||
|
class="absolute left-0 right-0 top-full z-10 mt-1 max-h-48 overflow-y-auto rounded-lg border border-gray-200 bg-white shadow-lg dark:border-dark-500 dark:bg-dark-700"
|
||||||
|
>
|
||||||
|
<button
|
||||||
|
v-for="user in searchResults"
|
||||||
|
:key="user.id"
|
||||||
|
type="button"
|
||||||
|
class="flex w-full items-center gap-2 px-3 py-1.5 text-left text-sm hover:bg-gray-50 dark:hover:bg-dark-600"
|
||||||
|
@click="selectUser(user)"
|
||||||
|
>
|
||||||
|
<span class="text-gray-400">#{{ user.id }}</span>
|
||||||
|
<span class="text-gray-900 dark:text-white">{{ user.username || user.email }}</span>
|
||||||
|
<span v-if="user.username" class="text-xs text-gray-400">{{ user.email }}</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="w-24">
|
||||||
|
<input
|
||||||
|
v-model.number="newRpm"
|
||||||
|
type="number"
|
||||||
|
step="1"
|
||||||
|
min="0"
|
||||||
|
autocomplete="off"
|
||||||
|
class="hide-spinner input w-full"
|
||||||
|
placeholder="100"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="btn btn-primary shrink-0"
|
||||||
|
:disabled="!selectedUser || newRpm == null || newRpm < 0"
|
||||||
|
@click="handleAddLocal"
|
||||||
|
>
|
||||||
|
{{ t('common.add') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-if="localEntries.length > 0" class="mt-3 flex items-center justify-end border-t border-gray-100 pt-3 dark:border-dark-600">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
:disabled="clearing"
|
||||||
|
class="rounded-lg border border-red-200 bg-red-50 px-3 py-1.5 text-sm font-medium text-red-600 transition-colors hover:bg-red-100 disabled:opacity-50 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400 dark:hover:bg-red-900/40"
|
||||||
|
@click="clearAllLocal"
|
||||||
|
>
|
||||||
|
<Icon v-if="clearing" name="refresh" size="sm" class="mr-1 inline animate-spin" />
|
||||||
|
{{ t('admin.groups.clearAll') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 加载状态 -->
|
||||||
|
<div v-if="loading" class="flex justify-center py-6">
|
||||||
|
<svg class="h-6 w-6 animate-spin text-primary-500" fill="none" viewBox="0 0 24 24">
|
||||||
|
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
|
||||||
|
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 列表 -->
|
||||||
|
<div v-else>
|
||||||
|
<h4 class="mb-2 text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.groups.rpmOverrides') }} ({{ localEntries.length }})
|
||||||
|
</h4>
|
||||||
|
|
||||||
|
<div v-if="localEntries.length === 0" class="py-6 text-center text-sm text-gray-400 dark:text-gray-500">
|
||||||
|
{{ t('admin.groups.noRpmOverrides') }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else>
|
||||||
|
<div class="overflow-hidden rounded-lg border border-gray-200 dark:border-dark-600">
|
||||||
|
<div class="max-h-[420px] overflow-y-auto">
|
||||||
|
<table class="w-full text-sm">
|
||||||
|
<thead class="sticky top-0 z-[1]">
|
||||||
|
<tr class="border-b border-gray-200 bg-gray-50 dark:border-dark-600 dark:bg-dark-700">
|
||||||
|
<th class="px-3 py-2 text-left text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('admin.groups.columns.userEmail') }}</th>
|
||||||
|
<th class="px-3 py-2 text-left text-xs font-medium text-gray-500 dark:text-gray-400">ID</th>
|
||||||
|
<th class="px-3 py-2 text-left text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('admin.groups.columns.userName') }}</th>
|
||||||
|
<th class="px-3 py-2 text-left text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('admin.groups.columns.userNotes') }}</th>
|
||||||
|
<th class="px-3 py-2 text-left text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('admin.groups.columns.userStatus') }}</th>
|
||||||
|
<th class="px-3 py-2 text-left text-xs font-medium text-gray-500 dark:text-gray-400" :title="t('admin.groups.columns.rpmOverrideHint')">{{ t('admin.groups.columns.rpmOverride') }}</th>
|
||||||
|
<th class="w-10 px-2 py-2"></th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody class="divide-y divide-gray-100 dark:divide-dark-600">
|
||||||
|
<tr
|
||||||
|
v-for="entry in paginatedLocalEntries"
|
||||||
|
:key="entry.user_id"
|
||||||
|
class="hover:bg-gray-50 dark:hover:bg-dark-700/50"
|
||||||
|
>
|
||||||
|
<td class="px-3 py-2 text-gray-600 dark:text-gray-400">{{ entry.user_email }}</td>
|
||||||
|
<td class="whitespace-nowrap px-3 py-2 text-gray-400 dark:text-gray-500">{{ entry.user_id }}</td>
|
||||||
|
<td class="whitespace-nowrap px-3 py-2 text-gray-900 dark:text-white">{{ entry.user_name || '-' }}</td>
|
||||||
|
<td class="max-w-[160px] truncate px-3 py-2 text-gray-500 dark:text-gray-400" :title="entry.user_notes">{{ entry.user_notes || '-' }}</td>
|
||||||
|
<td class="whitespace-nowrap px-3 py-2">
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'inline-flex rounded-full px-2 py-0.5 text-xs font-medium',
|
||||||
|
entry.user_status === 'active'
|
||||||
|
? 'bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-400'
|
||||||
|
: 'bg-gray-100 text-gray-600 dark:bg-dark-600 dark:text-gray-400'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
{{ entry.user_status }}
|
||||||
|
</span>
|
||||||
|
</td>
|
||||||
|
<td class="whitespace-nowrap px-3 py-2">
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
step="1"
|
||||||
|
min="0"
|
||||||
|
autocomplete="off"
|
||||||
|
:value="entry.rpm_override"
|
||||||
|
class="hide-spinner w-20 rounded border border-gray-200 bg-white px-2 py-1 text-center text-sm font-medium transition-colors focus:border-primary-500 focus:outline-none focus:ring-1 focus:ring-primary-500/20 dark:border-dark-500 dark:bg-dark-700 dark:focus:border-primary-500"
|
||||||
|
@change="updateLocalRpm(entry.user_id, ($event.target as HTMLInputElement).value)"
|
||||||
|
/>
|
||||||
|
</td>
|
||||||
|
<td class="px-2 py-2">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="rounded p-1 text-gray-400 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20 dark:hover:text-red-400"
|
||||||
|
@click="removeLocal(entry.user_id)"
|
||||||
|
>
|
||||||
|
<Icon name="trash" size="sm" />
|
||||||
|
</button>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Pagination
|
||||||
|
:total="localEntries.length"
|
||||||
|
:page="currentPage"
|
||||||
|
:page-size="pageSize"
|
||||||
|
@update:page="currentPage = $event"
|
||||||
|
@update:pageSize="handlePageSizeChange"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 底部 -->
|
||||||
|
<div class="flex items-center gap-3 border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
|
<template v-if="isDirty">
|
||||||
|
<span class="text-xs text-amber-600 dark:text-amber-400">{{ t('admin.groups.unsavedChanges') }}</span>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="text-xs font-medium text-primary-600 hover:text-primary-700 dark:text-primary-400 dark:hover:text-primary-300"
|
||||||
|
@click="handleCancel"
|
||||||
|
>
|
||||||
|
{{ t('admin.groups.revertChanges') }}
|
||||||
|
</button>
|
||||||
|
</template>
|
||||||
|
<div class="ml-auto flex items-center gap-3">
|
||||||
|
<button type="button" class="btn btn-sm px-4 py-1.5" @click="handleClose">
|
||||||
|
{{ t('common.close') }}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
v-if="isDirty"
|
||||||
|
type="button"
|
||||||
|
class="btn btn-primary btn-sm px-4 py-1.5"
|
||||||
|
:disabled="saving"
|
||||||
|
@click="handleSave"
|
||||||
|
>
|
||||||
|
<Icon v-if="saving" name="refresh" size="sm" class="mr-1 animate-spin" />
|
||||||
|
{{ t('common.save') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</BaseDialog>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, computed, watch } from 'vue'
|
||||||
|
import { useI18n } from 'vue-i18n'
|
||||||
|
import { useAppStore } from '@/stores/app'
|
||||||
|
import { adminAPI } from '@/api/admin'
|
||||||
|
import type { GroupRPMOverrideEntry } from '@/api/admin/groups'
|
||||||
|
import type { AdminGroup, AdminUser } from '@/types'
|
||||||
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
|
import Pagination from '@/components/common/Pagination.vue'
|
||||||
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
|
import PlatformIcon from '@/components/common/PlatformIcon.vue'
|
||||||
|
|
||||||
|
interface LocalEntry extends GroupRPMOverrideEntry {}
|
||||||
|
|
||||||
|
const props = defineProps<{
|
||||||
|
show: boolean
|
||||||
|
group: AdminGroup | null
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const emit = defineEmits<{
|
||||||
|
close: []
|
||||||
|
success: []
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const { t } = useI18n()
|
||||||
|
const appStore = useAppStore()
|
||||||
|
|
||||||
|
const loading = ref(false)
|
||||||
|
const saving = ref(false)
|
||||||
|
const serverEntries = ref<GroupRPMOverrideEntry[]>([])
|
||||||
|
const localEntries = ref<LocalEntry[]>([])
|
||||||
|
const searchQuery = ref('')
|
||||||
|
const searchResults = ref<AdminUser[]>([])
|
||||||
|
const showDropdown = ref(false)
|
||||||
|
const selectedUser = ref<AdminUser | null>(null)
|
||||||
|
const newRpm = ref<number | null>(null)
|
||||||
|
const currentPage = ref(1)
|
||||||
|
const pageSize = ref(10)
|
||||||
|
|
||||||
|
let searchTimeout: ReturnType<typeof setTimeout>
|
||||||
|
|
||||||
|
const platformColorClass = computed(() => {
|
||||||
|
switch (props.group?.platform) {
|
||||||
|
case 'anthropic': return 'text-orange-700 dark:text-orange-400'
|
||||||
|
case 'openai': return 'text-emerald-700 dark:text-emerald-400'
|
||||||
|
case 'antigravity': return 'text-purple-700 dark:text-purple-400'
|
||||||
|
default: return 'text-blue-700 dark:text-blue-400'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const isDirty = computed(() => {
|
||||||
|
if (localEntries.value.length !== serverEntries.value.length) return true
|
||||||
|
const serverMap = new Map(serverEntries.value.map(e => [e.user_id, e.rpm_override]))
|
||||||
|
return localEntries.value.some(e => serverMap.get(e.user_id) !== e.rpm_override)
|
||||||
|
})
|
||||||
|
|
||||||
|
const paginatedLocalEntries = computed(() => {
|
||||||
|
const start = (currentPage.value - 1) * pageSize.value
|
||||||
|
return localEntries.value.slice(start, start + pageSize.value)
|
||||||
|
})
|
||||||
|
|
||||||
|
const cloneEntries = (entries: GroupRPMOverrideEntry[]): LocalEntry[] => {
|
||||||
|
return entries.map(e => ({ ...e }))
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadEntries = async () => {
|
||||||
|
if (!props.group) return
|
||||||
|
loading.value = true
|
||||||
|
try {
|
||||||
|
serverEntries.value = await adminAPI.groups.getGroupRPMOverrides(props.group.id)
|
||||||
|
localEntries.value = cloneEntries(serverEntries.value)
|
||||||
|
adjustPage()
|
||||||
|
} catch (error) {
|
||||||
|
appStore.showError(t('admin.groups.failedToLoad'))
|
||||||
|
console.error('Error loading RPM overrides:', error)
|
||||||
|
} finally {
|
||||||
|
loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const adjustPage = () => {
|
||||||
|
const totalPages = Math.max(1, Math.ceil(localEntries.value.length / pageSize.value))
|
||||||
|
if (currentPage.value > totalPages) currentPage.value = totalPages
|
||||||
|
}
|
||||||
|
|
||||||
|
watch(() => props.show, (val) => {
|
||||||
|
if (val && props.group) {
|
||||||
|
currentPage.value = 1
|
||||||
|
searchQuery.value = ''
|
||||||
|
searchResults.value = []
|
||||||
|
selectedUser.value = null
|
||||||
|
newRpm.value = null
|
||||||
|
loadEntries()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const handlePageSizeChange = (newSize: number) => {
|
||||||
|
pageSize.value = newSize
|
||||||
|
currentPage.value = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleSearchUsers = () => {
|
||||||
|
clearTimeout(searchTimeout)
|
||||||
|
selectedUser.value = null
|
||||||
|
if (!searchQuery.value.trim()) {
|
||||||
|
searchResults.value = []
|
||||||
|
showDropdown.value = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
searchTimeout = setTimeout(async () => {
|
||||||
|
try {
|
||||||
|
const res = await adminAPI.users.list(1, 10, { search: searchQuery.value.trim() })
|
||||||
|
searchResults.value = res.items
|
||||||
|
showDropdown.value = true
|
||||||
|
} catch {
|
||||||
|
searchResults.value = []
|
||||||
|
}
|
||||||
|
}, 300)
|
||||||
|
}
|
||||||
|
|
||||||
|
const selectUser = (user: AdminUser) => {
|
||||||
|
selectedUser.value = user
|
||||||
|
searchQuery.value = user.email
|
||||||
|
showDropdown.value = false
|
||||||
|
searchResults.value = []
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleAddLocal = () => {
|
||||||
|
if (!selectedUser.value || newRpm.value == null || newRpm.value < 0) return
|
||||||
|
const user = selectedUser.value
|
||||||
|
const idx = localEntries.value.findIndex(e => e.user_id === user.id)
|
||||||
|
const entry: LocalEntry = {
|
||||||
|
user_id: user.id,
|
||||||
|
user_name: user.username || '',
|
||||||
|
user_email: user.email,
|
||||||
|
user_notes: user.notes || '',
|
||||||
|
user_status: user.status || 'active',
|
||||||
|
rpm_override: newRpm.value
|
||||||
|
}
|
||||||
|
if (idx >= 0) {
|
||||||
|
localEntries.value[idx] = entry
|
||||||
|
} else {
|
||||||
|
localEntries.value.push(entry)
|
||||||
|
}
|
||||||
|
searchQuery.value = ''
|
||||||
|
selectedUser.value = null
|
||||||
|
newRpm.value = null
|
||||||
|
adjustPage()
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateLocalRpm = (userId: number, value: string) => {
|
||||||
|
const num = parseInt(value, 10)
|
||||||
|
if (isNaN(num) || num < 0) return
|
||||||
|
const entry = localEntries.value.find(e => e.user_id === userId)
|
||||||
|
if (entry) entry.rpm_override = num
|
||||||
|
}
|
||||||
|
|
||||||
|
const removeLocal = (userId: number) => {
|
||||||
|
localEntries.value = localEntries.value.filter(e => e.user_id !== userId)
|
||||||
|
adjustPage()
|
||||||
|
}
|
||||||
|
|
||||||
|
const clearing = ref(false)
|
||||||
|
const clearAllLocal = async () => {
|
||||||
|
if (!props.group || clearing.value) return
|
||||||
|
clearing.value = true
|
||||||
|
try {
|
||||||
|
await adminAPI.groups.clearGroupRPMOverrides(props.group.id)
|
||||||
|
localEntries.value = []
|
||||||
|
serverEntries.value = []
|
||||||
|
appStore.showSuccess(t('admin.groups.rpmSaved'))
|
||||||
|
} catch (error) {
|
||||||
|
appStore.showError(t('admin.groups.failedToSave'))
|
||||||
|
console.error('Error clearing RPM overrides:', error)
|
||||||
|
} finally {
|
||||||
|
clearing.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleCancel = () => {
|
||||||
|
localEntries.value = cloneEntries(serverEntries.value)
|
||||||
|
adjustPage()
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleSave = async () => {
|
||||||
|
if (!props.group) return
|
||||||
|
saving.value = true
|
||||||
|
try {
|
||||||
|
const entries = localEntries.value.map(e => ({
|
||||||
|
user_id: e.user_id,
|
||||||
|
rpm_override: e.rpm_override
|
||||||
|
}))
|
||||||
|
await adminAPI.groups.batchSetGroupRPMOverrides(props.group.id, entries)
|
||||||
|
appStore.showSuccess(t('admin.groups.rpmSaved'))
|
||||||
|
emit('success')
|
||||||
|
emit('close')
|
||||||
|
} catch (error) {
|
||||||
|
appStore.showError(t('admin.groups.failedToSave'))
|
||||||
|
console.error('Error saving RPM overrides:', error)
|
||||||
|
} finally {
|
||||||
|
saving.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleClose = () => {
|
||||||
|
if (isDirty.value) {
|
||||||
|
localEntries.value = cloneEntries(serverEntries.value)
|
||||||
|
}
|
||||||
|
emit('close')
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleClickOutside = () => { showDropdown.value = false }
|
||||||
|
if (typeof document !== 'undefined') {
|
||||||
|
document.addEventListener('click', handleClickOutside)
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.hide-spinner::-webkit-outer-spin-button,
|
||||||
|
.hide-spinner::-webkit-inner-spin-button {
|
||||||
|
-webkit-appearance: none;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
.hide-spinner {
|
||||||
|
-moz-appearance: textfield;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -168,7 +168,8 @@
|
|||||||
step="0.001"
|
step="0.001"
|
||||||
min="0.001"
|
min="0.001"
|
||||||
autocomplete="off"
|
autocomplete="off"
|
||||||
:value="entry.rate_multiplier"
|
:value="entry.rate_multiplier ?? ''"
|
||||||
|
:placeholder="String(props.group?.rate_multiplier ?? 1)"
|
||||||
class="hide-spinner w-20 rounded border border-gray-200 bg-white px-2 py-1 text-center text-sm font-medium transition-colors focus:border-primary-500 focus:outline-none focus:ring-1 focus:ring-primary-500/20 dark:border-dark-500 dark:bg-dark-700 dark:focus:border-primary-500"
|
class="hide-spinner w-20 rounded border border-gray-200 bg-white px-2 py-1 text-center text-sm font-medium transition-colors focus:border-primary-500 focus:outline-none focus:ring-1 focus:ring-primary-500/20 dark:border-dark-500 dark:bg-dark-700 dark:focus:border-primary-500"
|
||||||
@change="updateLocalRate(entry.user_id, ($event.target as HTMLInputElement).value)"
|
@change="updateLocalRate(entry.user_id, ($event.target as HTMLInputElement).value)"
|
||||||
/>
|
/>
|
||||||
@@ -294,19 +295,17 @@ const showFinalRate = computed(() => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// 计算最终倍率预览
|
// 计算最终倍率预览
|
||||||
const computeFinalRate = (rate: number) => {
|
const computeFinalRate = (rate: number | null | undefined) => {
|
||||||
if (!batchFactor.value) return rate
|
const base = rate ?? props.group?.rate_multiplier ?? 1
|
||||||
return parseFloat((rate * batchFactor.value).toFixed(6))
|
if (!batchFactor.value) return base
|
||||||
|
return parseFloat((base * batchFactor.value).toFixed(6))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检测是否有未保存的修改
|
// 检测是否有未保存的修改
|
||||||
const isDirty = computed(() => {
|
const isDirty = computed(() => {
|
||||||
if (localEntries.value.length !== serverEntries.value.length) return true
|
if (localEntries.value.length !== serverEntries.value.length) return true
|
||||||
const serverMap = new Map(serverEntries.value.map(e => [e.user_id, e.rate_multiplier]))
|
const serverMap = new Map(serverEntries.value.map(e => [e.user_id, e.rate_multiplier ?? null]))
|
||||||
return localEntries.value.some(e => {
|
return localEntries.value.some(e => serverMap.get(e.user_id) !== (e.rate_multiplier ?? null))
|
||||||
const serverRate = serverMap.get(e.user_id)
|
|
||||||
return serverRate === undefined || serverRate !== e.rate_multiplier
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
const paginatedLocalEntries = computed(() => {
|
const paginatedLocalEntries = computed(() => {
|
||||||
@@ -322,7 +321,9 @@ const loadEntries = async () => {
|
|||||||
if (!props.group) return
|
if (!props.group) return
|
||||||
loading.value = true
|
loading.value = true
|
||||||
try {
|
try {
|
||||||
serverEntries.value = await adminAPI.groups.getGroupRateMultipliers(props.group.id)
|
const raw = await adminAPI.groups.getGroupRateMultipliers(props.group.id)
|
||||||
|
// 仅显示已设置 rate_multiplier 的条目;rpm_override 在另一个弹窗管理,保留不动
|
||||||
|
serverEntries.value = raw.filter(e => e.rate_multiplier != null)
|
||||||
localEntries.value = cloneEntries(serverEntries.value)
|
localEntries.value = cloneEntries(serverEntries.value)
|
||||||
adjustPage()
|
adjustPage()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -394,7 +395,8 @@ const handleAddLocal = () => {
|
|||||||
user_email: user.email,
|
user_email: user.email,
|
||||||
user_notes: user.notes || '',
|
user_notes: user.notes || '',
|
||||||
user_status: user.status || 'active',
|
user_status: user.status || 'active',
|
||||||
rate_multiplier: newRate.value
|
rate_multiplier: newRate.value,
|
||||||
|
rpm_override: null
|
||||||
}
|
}
|
||||||
if (idx >= 0) {
|
if (idx >= 0) {
|
||||||
localEntries.value[idx] = entry
|
localEntries.value[idx] = entry
|
||||||
@@ -409,12 +411,15 @@ const handleAddLocal = () => {
|
|||||||
|
|
||||||
// 本地修改倍率
|
// 本地修改倍率
|
||||||
const updateLocalRate = (userId: number, value: string) => {
|
const updateLocalRate = (userId: number, value: string) => {
|
||||||
|
const entry = localEntries.value.find(e => e.user_id === userId)
|
||||||
|
if (!entry) return
|
||||||
|
if (value.trim() === '') {
|
||||||
|
entry.rate_multiplier = null
|
||||||
|
return
|
||||||
|
}
|
||||||
const num = parseFloat(value)
|
const num = parseFloat(value)
|
||||||
if (isNaN(num)) return
|
if (isNaN(num)) return
|
||||||
const entry = localEntries.value.find(e => e.user_id === userId)
|
entry.rate_multiplier = num
|
||||||
if (entry) {
|
|
||||||
entry.rate_multiplier = num
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 本地删除
|
// 本地删除
|
||||||
@@ -427,7 +432,9 @@ const removeLocal = (userId: number) => {
|
|||||||
const applyBatchFactor = () => {
|
const applyBatchFactor = () => {
|
||||||
if (!batchFactor.value || batchFactor.value <= 0) return
|
if (!batchFactor.value || batchFactor.value <= 0) return
|
||||||
for (const entry of localEntries.value) {
|
for (const entry of localEntries.value) {
|
||||||
entry.rate_multiplier = parseFloat((entry.rate_multiplier * batchFactor.value).toFixed(6))
|
if (entry.rate_multiplier != null) {
|
||||||
|
entry.rate_multiplier = parseFloat((entry.rate_multiplier * batchFactor.value).toFixed(6))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
batchFactor.value = null
|
batchFactor.value = null
|
||||||
}
|
}
|
||||||
@@ -444,15 +451,17 @@ const handleCancel = () => {
|
|||||||
adjustPage()
|
adjustPage()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存:一次性提交所有数据
|
// 保存:一次性提交所有数据(只提交 rate_multiplier;rpm_override 由独立弹窗管理)
|
||||||
const handleSave = async () => {
|
const handleSave = async () => {
|
||||||
if (!props.group) return
|
if (!props.group) return
|
||||||
saving.value = true
|
saving.value = true
|
||||||
try {
|
try {
|
||||||
const entries = localEntries.value.map(e => ({
|
const entries = localEntries.value
|
||||||
user_id: e.user_id,
|
.filter(e => e.rate_multiplier != null)
|
||||||
rate_multiplier: e.rate_multiplier
|
.map(e => ({
|
||||||
}))
|
user_id: e.user_id,
|
||||||
|
rate_multiplier: e.rate_multiplier as number
|
||||||
|
}))
|
||||||
await adminAPI.groups.batchSetGroupRateMultipliers(props.group.id, entries)
|
await adminAPI.groups.batchSetGroupRateMultipliers(props.group.id, entries)
|
||||||
appStore.showSuccess(t('admin.groups.rateSaved'))
|
appStore.showSuccess(t('admin.groups.rateSaved'))
|
||||||
emit('success')
|
emit('success')
|
||||||
|
|||||||
@@ -35,6 +35,18 @@
|
|||||||
<input v-model.number="form.concurrency" type="number" class="input" />
|
<input v-model.number="form.concurrency" type="number" class="input" />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="input-label">{{ t('admin.users.form.rpmLimit') }}</label>
|
||||||
|
<input
|
||||||
|
v-model.number="form.rpm_limit"
|
||||||
|
type="number"
|
||||||
|
min="0"
|
||||||
|
step="1"
|
||||||
|
class="input"
|
||||||
|
:placeholder="t('admin.users.form.rpmLimitPlaceholder')"
|
||||||
|
/>
|
||||||
|
<p class="input-hint">{{ t('admin.users.form.rpmLimitHint') }}</p>
|
||||||
|
</div>
|
||||||
</form>
|
</form>
|
||||||
<template #footer>
|
<template #footer>
|
||||||
<div class="flex justify-end gap-3">
|
<div class="flex justify-end gap-3">
|
||||||
@@ -57,7 +69,7 @@ import Icon from '@/components/icons/Icon.vue'
|
|||||||
const props = defineProps<{ show: boolean }>()
|
const props = defineProps<{ show: boolean }>()
|
||||||
const emit = defineEmits(['close', 'success']); const { t } = useI18n()
|
const emit = defineEmits(['close', 'success']); const { t } = useI18n()
|
||||||
|
|
||||||
const form = reactive({ email: '', password: '', username: '', notes: '', balance: 0, concurrency: 1 })
|
const form = reactive({ email: '', password: '', username: '', notes: '', balance: 0, concurrency: 1, rpm_limit: 0 })
|
||||||
|
|
||||||
const { loading, submit } = useForm({
|
const { loading, submit } = useForm({
|
||||||
form,
|
form,
|
||||||
@@ -68,7 +80,7 @@ const { loading, submit } = useForm({
|
|||||||
successMsg: t('admin.users.userCreated')
|
successMsg: t('admin.users.userCreated')
|
||||||
})
|
})
|
||||||
|
|
||||||
watch(() => props.show, (v) => { if(v) Object.assign(form, { email: '', password: '', username: '', notes: '', balance: 0, concurrency: 1 }) })
|
watch(() => props.show, (v) => { if(v) Object.assign(form, { email: '', password: '', username: '', notes: '', balance: 0, concurrency: 1, rpm_limit: 0 }) })
|
||||||
|
|
||||||
const generateRandomPassword = () => {
|
const generateRandomPassword = () => {
|
||||||
const chars = 'ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjkmnpqrstuvwxyz23456789!@#$%^&*'
|
const chars = 'ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjkmnpqrstuvwxyz23456789!@#$%^&*'
|
||||||
|
|||||||
@@ -37,6 +37,18 @@
|
|||||||
<label class="input-label">{{ t('admin.users.columns.concurrency') }}</label>
|
<label class="input-label">{{ t('admin.users.columns.concurrency') }}</label>
|
||||||
<input v-model.number="form.concurrency" type="number" class="input" />
|
<input v-model.number="form.concurrency" type="number" class="input" />
|
||||||
</div>
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="input-label">{{ t('admin.users.form.rpmLimit') }}</label>
|
||||||
|
<input
|
||||||
|
v-model.number="form.rpm_limit"
|
||||||
|
type="number"
|
||||||
|
min="0"
|
||||||
|
step="1"
|
||||||
|
class="input"
|
||||||
|
:placeholder="t('admin.users.form.rpmLimitPlaceholder')"
|
||||||
|
/>
|
||||||
|
<p class="input-hint">{{ t('admin.users.form.rpmLimitHint') }}</p>
|
||||||
|
</div>
|
||||||
<UserAttributeForm v-model="form.customAttributes" :user-id="user?.id" />
|
<UserAttributeForm v-model="form.customAttributes" :user-id="user?.id" />
|
||||||
</form>
|
</form>
|
||||||
<template #footer>
|
<template #footer>
|
||||||
@@ -66,11 +78,11 @@ const emit = defineEmits(['close', 'success'])
|
|||||||
const { t } = useI18n(); const appStore = useAppStore(); const { copyToClipboard } = useClipboard()
|
const { t } = useI18n(); const appStore = useAppStore(); const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
const submitting = ref(false); const passwordCopied = ref(false)
|
const submitting = ref(false); const passwordCopied = ref(false)
|
||||||
const form = reactive({ email: '', password: '', username: '', notes: '', concurrency: 1, customAttributes: {} as UserAttributeValuesMap })
|
const form = reactive({ email: '', password: '', username: '', notes: '', concurrency: 1, rpm_limit: 0, customAttributes: {} as UserAttributeValuesMap })
|
||||||
|
|
||||||
watch(() => props.user, (u) => {
|
watch(() => props.user, (u) => {
|
||||||
if (u) {
|
if (u) {
|
||||||
Object.assign(form, { email: u.email, password: '', username: u.username || '', notes: u.notes || '', concurrency: u.concurrency, customAttributes: {} })
|
Object.assign(form, { email: u.email, password: '', username: u.username || '', notes: u.notes || '', concurrency: u.concurrency, rpm_limit: u.rpm_limit ?? 0, customAttributes: {} })
|
||||||
passwordCopied.value = false
|
passwordCopied.value = false
|
||||||
}
|
}
|
||||||
}, { immediate: true })
|
}, { immediate: true })
|
||||||
@@ -97,7 +109,7 @@ const handleUpdateUser = async () => {
|
|||||||
}
|
}
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
try {
|
try {
|
||||||
const data: any = { email: form.email, username: form.username, notes: form.notes, concurrency: form.concurrency }
|
const data: any = { email: form.email, username: form.username, notes: form.notes, concurrency: form.concurrency, rpm_limit: form.rpm_limit }
|
||||||
if (form.password.trim()) data.password = form.password.trim()
|
if (form.password.trim()) data.password = form.password.trim()
|
||||||
await adminAPI.users.update(props.user.id, data)
|
await adminAPI.users.update(props.user.id, data)
|
||||||
if (Object.keys(form.customAttributes).length > 0) await adminAPI.userAttributes.updateUserAttributeValues(props.user.id, form.customAttributes)
|
if (Object.keys(form.customAttributes).length > 0) await adminAPI.userAttributes.updateUserAttributeValues(props.user.id, form.customAttributes)
|
||||||
|
|||||||
@@ -633,6 +633,22 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
|
|||||||
xhigh: {}
|
xhigh: {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
'gpt-5.5': {
|
||||||
|
name: 'GPT-5.5',
|
||||||
|
limit: {
|
||||||
|
context: 1050000,
|
||||||
|
output: 128000
|
||||||
|
},
|
||||||
|
options: {
|
||||||
|
store: false
|
||||||
|
},
|
||||||
|
variants: {
|
||||||
|
low: {},
|
||||||
|
medium: {},
|
||||||
|
high: {},
|
||||||
|
xhigh: {}
|
||||||
|
}
|
||||||
|
},
|
||||||
'gpt-5.4': {
|
'gpt-5.4': {
|
||||||
name: 'GPT-5.4',
|
name: 'GPT-5.4',
|
||||||
limit: {
|
limit: {
|
||||||
|
|||||||
@@ -6,6 +6,19 @@ vi.mock('@/stores/app', () => ({
|
|||||||
})
|
})
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
vi.mock('vue-i18n', () => ({
|
||||||
|
useI18n: () => ({
|
||||||
|
t: (key: string) => {
|
||||||
|
const messages: Record<string, string> = {
|
||||||
|
'admin.accounts.oauth.openai.failedToExchangeCode': 'OpenAI 授权码兑换失败',
|
||||||
|
'admin.accounts.oauth.openai.errors.OPENAI_OAUTH_PROXY_REQUIRED':
|
||||||
|
'未设置代理,当前服务器无法直连 OpenAI,导致 OpenAI OAuth 请求失败。请先选择可访问 OpenAI 的代理后重试;如果授权码已失效,请重新生成授权链接。'
|
||||||
|
}
|
||||||
|
return messages[key] ?? key
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
vi.mock('@/api/admin', () => ({
|
vi.mock('@/api/admin', () => ({
|
||||||
adminAPI: {
|
adminAPI: {
|
||||||
accounts: {
|
accounts: {
|
||||||
@@ -17,6 +30,7 @@ vi.mock('@/api/admin', () => ({
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
|
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
|
||||||
|
import { adminAPI } from '@/api/admin'
|
||||||
|
|
||||||
describe('useOpenAIOAuth.buildCredentials', () => {
|
describe('useOpenAIOAuth.buildCredentials', () => {
|
||||||
it('should keep client_id when token response contains it', () => {
|
it('should keep client_id when token response contains it', () => {
|
||||||
@@ -46,3 +60,21 @@ describe('useOpenAIOAuth.buildCredentials', () => {
|
|||||||
expect(creds.refresh_token).toBe('rt')
|
expect(creds.refresh_token).toBe('rt')
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('useOpenAIOAuth.exchangeAuthCode', () => {
|
||||||
|
it('shows a clear proxy hint when code exchange fails without a proxy', async () => {
|
||||||
|
vi.mocked(adminAPI.accounts.exchangeCode).mockRejectedValueOnce({
|
||||||
|
status: 502,
|
||||||
|
reason: 'OPENAI_OAUTH_PROXY_REQUIRED',
|
||||||
|
message: 'OpenAI OAuth token exchange failed: no proxy is configured.'
|
||||||
|
})
|
||||||
|
const oauth = useOpenAIOAuth()
|
||||||
|
|
||||||
|
const tokenInfo = await oauth.exchangeAuthCode('code', 'session-id', 'state')
|
||||||
|
|
||||||
|
expect(tokenInfo).toBeNull()
|
||||||
|
expect(oauth.error.value).toBe(
|
||||||
|
'未设置代理,当前服务器无法直连 OpenAI,导致 OpenAI OAuth 请求失败。请先选择可访问 OpenAI 的代理后重试;如果授权码已失效,请重新生成授权链接。'
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ const openaiModels = [
|
|||||||
// GPT-5.2 系列
|
// GPT-5.2 系列
|
||||||
'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest',
|
'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest',
|
||||||
'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11',
|
'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11',
|
||||||
|
// GPT-5.5 系列
|
||||||
|
'gpt-5.5',
|
||||||
// GPT-5.4 系列
|
// GPT-5.4 系列
|
||||||
'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-2026-03-05',
|
'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-2026-03-05',
|
||||||
// GPT-5.3 系列
|
// GPT-5.3 系列
|
||||||
@@ -260,6 +262,7 @@ const openaiPresetMappings = [
|
|||||||
{ label: 'o3', from: 'o3', to: 'o3', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
{ label: 'o3', from: 'o3', to: 'o3', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
||||||
{ label: 'GPT-5.3 Codex Spark', from: 'gpt-5.3-codex-spark', to: 'gpt-5.3-codex-spark', color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' },
|
{ label: 'GPT-5.3 Codex Spark', from: 'gpt-5.3-codex-spark', to: 'gpt-5.3-codex-spark', color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' },
|
||||||
{ label: 'GPT-5.2', from: 'gpt-5.2', to: 'gpt-5.2', color: 'bg-red-100 text-red-700 hover:bg-red-200 dark:bg-red-900/30 dark:text-red-400' },
|
{ label: 'GPT-5.2', from: 'gpt-5.2', to: 'gpt-5.2', color: 'bg-red-100 text-red-700 hover:bg-red-200 dark:bg-red-900/30 dark:text-red-400' },
|
||||||
|
{ label: 'GPT-5.5', from: 'gpt-5.5', to: 'gpt-5.5', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' },
|
||||||
{ label: 'GPT-5.4', from: 'gpt-5.4', to: 'gpt-5.4', color: 'bg-rose-100 text-rose-700 hover:bg-rose-200 dark:bg-rose-900/30 dark:text-rose-400' },
|
{ label: 'GPT-5.4', from: 'gpt-5.4', to: 'gpt-5.4', color: 'bg-rose-100 text-rose-700 hover:bg-rose-200 dark:bg-rose-900/30 dark:text-rose-400' },
|
||||||
{ label: 'Haiku→5.4', from: 'claude-haiku-4-5-20251001', to: 'gpt-5.4', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
{ label: 'Haiku→5.4', from: 'claude-haiku-4-5-20251001', to: 'gpt-5.4', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
||||||
{ label: 'Opus→5.4', from: 'claude-opus-4-6', to: 'gpt-5.4', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
{ label: 'Opus→5.4', from: 'claude-opus-4-6', to: 'gpt-5.4', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user