mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-16 21:04:45 +08:00
merge: 合并 upstream/main
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -77,6 +77,7 @@ temp/
|
|||||||
*.temp
|
*.temp
|
||||||
*.log
|
*.log
|
||||||
*.bak
|
*.bak
|
||||||
|
.cache/
|
||||||
|
|
||||||
# ===================
|
# ===================
|
||||||
# 构建产物
|
# 构建产物
|
||||||
|
|||||||
@@ -87,9 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
|
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
|
||||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||||
|
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService)
|
||||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService)
|
||||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
@@ -99,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
|
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||||
@@ -116,7 +117,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
systemHandler := handler.ProvideSystemHandler(updateService)
|
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
|
||||||
|
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||||
|
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||||
|
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
||||||
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
||||||
pricingRemoteClient := repository.NewPricingRemoteClient()
|
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -127,10 +132,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
timingWheelService := service.ProvideTimingWheelService()
|
timingWheelService := service.ProvideTimingWheelService()
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
|
|
||||||
stdsql "database/sql"
|
stdsql "database/sql"
|
||||||
@@ -55,6 +57,10 @@ type Client struct {
|
|||||||
User *UserClient
|
User *UserClient
|
||||||
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
|
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
|
||||||
UserAllowedGroup *UserAllowedGroupClient
|
UserAllowedGroup *UserAllowedGroupClient
|
||||||
|
// UserAttributeDefinition is the client for interacting with the UserAttributeDefinition builders.
|
||||||
|
UserAttributeDefinition *UserAttributeDefinitionClient
|
||||||
|
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
|
||||||
|
UserAttributeValue *UserAttributeValueClient
|
||||||
// UserSubscription is the client for interacting with the UserSubscription builders.
|
// UserSubscription is the client for interacting with the UserSubscription builders.
|
||||||
UserSubscription *UserSubscriptionClient
|
UserSubscription *UserSubscriptionClient
|
||||||
}
|
}
|
||||||
@@ -78,6 +84,8 @@ func (c *Client) init() {
|
|||||||
c.UsageLog = NewUsageLogClient(c.config)
|
c.UsageLog = NewUsageLogClient(c.config)
|
||||||
c.User = NewUserClient(c.config)
|
c.User = NewUserClient(c.config)
|
||||||
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
|
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
|
||||||
|
c.UserAttributeDefinition = NewUserAttributeDefinitionClient(c.config)
|
||||||
|
c.UserAttributeValue = NewUserAttributeValueClient(c.config)
|
||||||
c.UserSubscription = NewUserSubscriptionClient(c.config)
|
c.UserSubscription = NewUserSubscriptionClient(c.config)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,19 +177,21 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
|||||||
cfg := c.config
|
cfg := c.config
|
||||||
cfg.driver = tx
|
cfg.driver = tx
|
||||||
return &Tx{
|
return &Tx{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
config: cfg,
|
config: cfg,
|
||||||
Account: NewAccountClient(cfg),
|
Account: NewAccountClient(cfg),
|
||||||
AccountGroup: NewAccountGroupClient(cfg),
|
AccountGroup: NewAccountGroupClient(cfg),
|
||||||
ApiKey: NewApiKeyClient(cfg),
|
ApiKey: NewApiKeyClient(cfg),
|
||||||
Group: NewGroupClient(cfg),
|
Group: NewGroupClient(cfg),
|
||||||
Proxy: NewProxyClient(cfg),
|
Proxy: NewProxyClient(cfg),
|
||||||
RedeemCode: NewRedeemCodeClient(cfg),
|
RedeemCode: NewRedeemCodeClient(cfg),
|
||||||
Setting: NewSettingClient(cfg),
|
Setting: NewSettingClient(cfg),
|
||||||
UsageLog: NewUsageLogClient(cfg),
|
UsageLog: NewUsageLogClient(cfg),
|
||||||
User: NewUserClient(cfg),
|
User: NewUserClient(cfg),
|
||||||
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
||||||
UserSubscription: NewUserSubscriptionClient(cfg),
|
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
|
||||||
|
UserAttributeValue: NewUserAttributeValueClient(cfg),
|
||||||
|
UserSubscription: NewUserSubscriptionClient(cfg),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,19 +209,21 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
|
|||||||
cfg := c.config
|
cfg := c.config
|
||||||
cfg.driver = &txDriver{tx: tx, drv: c.driver}
|
cfg.driver = &txDriver{tx: tx, drv: c.driver}
|
||||||
return &Tx{
|
return &Tx{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
config: cfg,
|
config: cfg,
|
||||||
Account: NewAccountClient(cfg),
|
Account: NewAccountClient(cfg),
|
||||||
AccountGroup: NewAccountGroupClient(cfg),
|
AccountGroup: NewAccountGroupClient(cfg),
|
||||||
ApiKey: NewApiKeyClient(cfg),
|
ApiKey: NewApiKeyClient(cfg),
|
||||||
Group: NewGroupClient(cfg),
|
Group: NewGroupClient(cfg),
|
||||||
Proxy: NewProxyClient(cfg),
|
Proxy: NewProxyClient(cfg),
|
||||||
RedeemCode: NewRedeemCodeClient(cfg),
|
RedeemCode: NewRedeemCodeClient(cfg),
|
||||||
Setting: NewSettingClient(cfg),
|
Setting: NewSettingClient(cfg),
|
||||||
UsageLog: NewUsageLogClient(cfg),
|
UsageLog: NewUsageLogClient(cfg),
|
||||||
User: NewUserClient(cfg),
|
User: NewUserClient(cfg),
|
||||||
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
||||||
UserSubscription: NewUserSubscriptionClient(cfg),
|
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
|
||||||
|
UserAttributeValue: NewUserAttributeValueClient(cfg),
|
||||||
|
UserSubscription: NewUserSubscriptionClient(cfg),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,7 +254,8 @@ func (c *Client) Close() error {
|
|||||||
func (c *Client) Use(hooks ...Hook) {
|
func (c *Client) Use(hooks ...Hook) {
|
||||||
for _, n := range []interface{ Use(...Hook) }{
|
for _, n := range []interface{ Use(...Hook) }{
|
||||||
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
|
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
|
||||||
c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription,
|
c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition,
|
||||||
|
c.UserAttributeValue, c.UserSubscription,
|
||||||
} {
|
} {
|
||||||
n.Use(hooks...)
|
n.Use(hooks...)
|
||||||
}
|
}
|
||||||
@@ -253,7 +266,8 @@ func (c *Client) Use(hooks ...Hook) {
|
|||||||
func (c *Client) Intercept(interceptors ...Interceptor) {
|
func (c *Client) Intercept(interceptors ...Interceptor) {
|
||||||
for _, n := range []interface{ Intercept(...Interceptor) }{
|
for _, n := range []interface{ Intercept(...Interceptor) }{
|
||||||
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
|
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
|
||||||
c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription,
|
c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition,
|
||||||
|
c.UserAttributeValue, c.UserSubscription,
|
||||||
} {
|
} {
|
||||||
n.Intercept(interceptors...)
|
n.Intercept(interceptors...)
|
||||||
}
|
}
|
||||||
@@ -282,6 +296,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
|
|||||||
return c.User.mutate(ctx, m)
|
return c.User.mutate(ctx, m)
|
||||||
case *UserAllowedGroupMutation:
|
case *UserAllowedGroupMutation:
|
||||||
return c.UserAllowedGroup.mutate(ctx, m)
|
return c.UserAllowedGroup.mutate(ctx, m)
|
||||||
|
case *UserAttributeDefinitionMutation:
|
||||||
|
return c.UserAttributeDefinition.mutate(ctx, m)
|
||||||
|
case *UserAttributeValueMutation:
|
||||||
|
return c.UserAttributeValue.mutate(ctx, m)
|
||||||
case *UserSubscriptionMutation:
|
case *UserSubscriptionMutation:
|
||||||
return c.UserSubscription.mutate(ctx, m)
|
return c.UserSubscription.mutate(ctx, m)
|
||||||
default:
|
default:
|
||||||
@@ -1916,6 +1934,22 @@ func (c *UserClient) QueryUsageLogs(_m *User) *UsageLogQuery {
|
|||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryAttributeValues queries the attribute_values edge of a User.
|
||||||
|
func (c *UserClient) QueryAttributeValues(_m *User) *UserAttributeValueQuery {
|
||||||
|
query := (&UserAttributeValueClient{config: c.config}).Query()
|
||||||
|
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||||
|
id := _m.ID
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(user.Table, user.FieldID, id),
|
||||||
|
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.O2M, false, user.AttributeValuesTable, user.AttributeValuesColumn),
|
||||||
|
)
|
||||||
|
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||||
|
return fromV, nil
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
|
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
|
||||||
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
|
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
|
||||||
query := (&UserAllowedGroupClient{config: c.config}).Query()
|
query := (&UserAllowedGroupClient{config: c.config}).Query()
|
||||||
@@ -2075,6 +2109,322 @@ func (c *UserAllowedGroupClient) mutate(ctx context.Context, m *UserAllowedGroup
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserAttributeDefinitionClient is a client for the UserAttributeDefinition schema.
|
||||||
|
type UserAttributeDefinitionClient struct {
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserAttributeDefinitionClient returns a client for the UserAttributeDefinition from the given config.
|
||||||
|
func NewUserAttributeDefinitionClient(c config) *UserAttributeDefinitionClient {
|
||||||
|
return &UserAttributeDefinitionClient{config: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use adds a list of mutation hooks to the hooks stack.
|
||||||
|
// A call to `Use(f, g, h)` equals to `userattributedefinition.Hooks(f(g(h())))`.
|
||||||
|
func (c *UserAttributeDefinitionClient) Use(hooks ...Hook) {
|
||||||
|
c.hooks.UserAttributeDefinition = append(c.hooks.UserAttributeDefinition, hooks...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||||
|
// A call to `Intercept(f, g, h)` equals to `userattributedefinition.Intercept(f(g(h())))`.
|
||||||
|
func (c *UserAttributeDefinitionClient) Intercept(interceptors ...Interceptor) {
|
||||||
|
c.inters.UserAttributeDefinition = append(c.inters.UserAttributeDefinition, interceptors...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create returns a builder for creating a UserAttributeDefinition entity.
|
||||||
|
func (c *UserAttributeDefinitionClient) Create() *UserAttributeDefinitionCreate {
|
||||||
|
mutation := newUserAttributeDefinitionMutation(c.config, OpCreate)
|
||||||
|
return &UserAttributeDefinitionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateBulk returns a builder for creating a bulk of UserAttributeDefinition entities.
|
||||||
|
func (c *UserAttributeDefinitionClient) CreateBulk(builders ...*UserAttributeDefinitionCreate) *UserAttributeDefinitionCreateBulk {
|
||||||
|
return &UserAttributeDefinitionCreateBulk{config: c.config, builders: builders}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||||
|
// a builder and applies setFunc on it.
|
||||||
|
func (c *UserAttributeDefinitionClient) MapCreateBulk(slice any, setFunc func(*UserAttributeDefinitionCreate, int)) *UserAttributeDefinitionCreateBulk {
|
||||||
|
rv := reflect.ValueOf(slice)
|
||||||
|
if rv.Kind() != reflect.Slice {
|
||||||
|
return &UserAttributeDefinitionCreateBulk{err: fmt.Errorf("calling to UserAttributeDefinitionClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||||
|
}
|
||||||
|
builders := make([]*UserAttributeDefinitionCreate, rv.Len())
|
||||||
|
for i := 0; i < rv.Len(); i++ {
|
||||||
|
builders[i] = c.Create()
|
||||||
|
setFunc(builders[i], i)
|
||||||
|
}
|
||||||
|
return &UserAttributeDefinitionCreateBulk{config: c.config, builders: builders}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update returns an update builder for UserAttributeDefinition.
|
||||||
|
func (c *UserAttributeDefinitionClient) Update() *UserAttributeDefinitionUpdate {
|
||||||
|
mutation := newUserAttributeDefinitionMutation(c.config, OpUpdate)
|
||||||
|
return &UserAttributeDefinitionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOne returns an update builder for the given entity.
|
||||||
|
func (c *UserAttributeDefinitionClient) UpdateOne(_m *UserAttributeDefinition) *UserAttributeDefinitionUpdateOne {
|
||||||
|
mutation := newUserAttributeDefinitionMutation(c.config, OpUpdateOne, withUserAttributeDefinition(_m))
|
||||||
|
return &UserAttributeDefinitionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOneID returns an update builder for the given id.
|
||||||
|
func (c *UserAttributeDefinitionClient) UpdateOneID(id int64) *UserAttributeDefinitionUpdateOne {
|
||||||
|
mutation := newUserAttributeDefinitionMutation(c.config, OpUpdateOne, withUserAttributeDefinitionID(id))
|
||||||
|
return &UserAttributeDefinitionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete returns a delete builder for UserAttributeDefinition.
|
||||||
|
func (c *UserAttributeDefinitionClient) Delete() *UserAttributeDefinitionDelete {
|
||||||
|
mutation := newUserAttributeDefinitionMutation(c.config, OpDelete)
|
||||||
|
return &UserAttributeDefinitionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOne returns a builder for deleting the given entity.
|
||||||
|
func (c *UserAttributeDefinitionClient) DeleteOne(_m *UserAttributeDefinition) *UserAttributeDefinitionDeleteOne {
|
||||||
|
return c.DeleteOneID(_m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||||
|
func (c *UserAttributeDefinitionClient) DeleteOneID(id int64) *UserAttributeDefinitionDeleteOne {
|
||||||
|
builder := c.Delete().Where(userattributedefinition.ID(id))
|
||||||
|
builder.mutation.id = &id
|
||||||
|
builder.mutation.op = OpDeleteOne
|
||||||
|
return &UserAttributeDefinitionDeleteOne{builder}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query returns a query builder for UserAttributeDefinition.
|
||||||
|
func (c *UserAttributeDefinitionClient) Query() *UserAttributeDefinitionQuery {
|
||||||
|
return &UserAttributeDefinitionQuery{
|
||||||
|
config: c.config,
|
||||||
|
ctx: &QueryContext{Type: TypeUserAttributeDefinition},
|
||||||
|
inters: c.Interceptors(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a UserAttributeDefinition entity by its id.
|
||||||
|
func (c *UserAttributeDefinitionClient) Get(ctx context.Context, id int64) (*UserAttributeDefinition, error) {
|
||||||
|
return c.Query().Where(userattributedefinition.ID(id)).Only(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetX is like Get, but panics if an error occurs.
|
||||||
|
func (c *UserAttributeDefinitionClient) GetX(ctx context.Context, id int64) *UserAttributeDefinition {
|
||||||
|
obj, err := c.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryValues queries the values edge of a UserAttributeDefinition.
|
||||||
|
func (c *UserAttributeDefinitionClient) QueryValues(_m *UserAttributeDefinition) *UserAttributeValueQuery {
|
||||||
|
query := (&UserAttributeValueClient{config: c.config}).Query()
|
||||||
|
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||||
|
id := _m.ID
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(userattributedefinition.Table, userattributedefinition.FieldID, id),
|
||||||
|
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.O2M, false, userattributedefinition.ValuesTable, userattributedefinition.ValuesColumn),
|
||||||
|
)
|
||||||
|
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||||
|
return fromV, nil
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hooks returns the client hooks.
|
||||||
|
func (c *UserAttributeDefinitionClient) Hooks() []Hook {
|
||||||
|
hooks := c.hooks.UserAttributeDefinition
|
||||||
|
return append(hooks[:len(hooks):len(hooks)], userattributedefinition.Hooks[:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interceptors returns the client interceptors.
|
||||||
|
func (c *UserAttributeDefinitionClient) Interceptors() []Interceptor {
|
||||||
|
inters := c.inters.UserAttributeDefinition
|
||||||
|
return append(inters[:len(inters):len(inters)], userattributedefinition.Interceptors[:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserAttributeDefinitionClient) mutate(ctx context.Context, m *UserAttributeDefinitionMutation) (Value, error) {
|
||||||
|
switch m.Op() {
|
||||||
|
case OpCreate:
|
||||||
|
return (&UserAttributeDefinitionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||||
|
case OpUpdate:
|
||||||
|
return (&UserAttributeDefinitionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||||
|
case OpUpdateOne:
|
||||||
|
return (&UserAttributeDefinitionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||||
|
case OpDelete, OpDeleteOne:
|
||||||
|
return (&UserAttributeDefinitionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("ent: unknown UserAttributeDefinition mutation op: %q", m.Op())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueClient is a client for the UserAttributeValue schema.
|
||||||
|
type UserAttributeValueClient struct {
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserAttributeValueClient returns a client for the UserAttributeValue from the given config.
|
||||||
|
func NewUserAttributeValueClient(c config) *UserAttributeValueClient {
|
||||||
|
return &UserAttributeValueClient{config: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use adds a list of mutation hooks to the hooks stack.
|
||||||
|
// A call to `Use(f, g, h)` equals to `userattributevalue.Hooks(f(g(h())))`.
|
||||||
|
func (c *UserAttributeValueClient) Use(hooks ...Hook) {
|
||||||
|
c.hooks.UserAttributeValue = append(c.hooks.UserAttributeValue, hooks...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||||
|
// A call to `Intercept(f, g, h)` equals to `userattributevalue.Intercept(f(g(h())))`.
|
||||||
|
func (c *UserAttributeValueClient) Intercept(interceptors ...Interceptor) {
|
||||||
|
c.inters.UserAttributeValue = append(c.inters.UserAttributeValue, interceptors...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create returns a builder for creating a UserAttributeValue entity.
|
||||||
|
func (c *UserAttributeValueClient) Create() *UserAttributeValueCreate {
|
||||||
|
mutation := newUserAttributeValueMutation(c.config, OpCreate)
|
||||||
|
return &UserAttributeValueCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateBulk returns a builder for creating a bulk of UserAttributeValue entities.
|
||||||
|
func (c *UserAttributeValueClient) CreateBulk(builders ...*UserAttributeValueCreate) *UserAttributeValueCreateBulk {
|
||||||
|
return &UserAttributeValueCreateBulk{config: c.config, builders: builders}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||||
|
// a builder and applies setFunc on it.
|
||||||
|
func (c *UserAttributeValueClient) MapCreateBulk(slice any, setFunc func(*UserAttributeValueCreate, int)) *UserAttributeValueCreateBulk {
|
||||||
|
rv := reflect.ValueOf(slice)
|
||||||
|
if rv.Kind() != reflect.Slice {
|
||||||
|
return &UserAttributeValueCreateBulk{err: fmt.Errorf("calling to UserAttributeValueClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||||
|
}
|
||||||
|
builders := make([]*UserAttributeValueCreate, rv.Len())
|
||||||
|
for i := 0; i < rv.Len(); i++ {
|
||||||
|
builders[i] = c.Create()
|
||||||
|
setFunc(builders[i], i)
|
||||||
|
}
|
||||||
|
return &UserAttributeValueCreateBulk{config: c.config, builders: builders}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update returns an update builder for UserAttributeValue.
|
||||||
|
func (c *UserAttributeValueClient) Update() *UserAttributeValueUpdate {
|
||||||
|
mutation := newUserAttributeValueMutation(c.config, OpUpdate)
|
||||||
|
return &UserAttributeValueUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOne returns an update builder for the given entity.
|
||||||
|
func (c *UserAttributeValueClient) UpdateOne(_m *UserAttributeValue) *UserAttributeValueUpdateOne {
|
||||||
|
mutation := newUserAttributeValueMutation(c.config, OpUpdateOne, withUserAttributeValue(_m))
|
||||||
|
return &UserAttributeValueUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOneID returns an update builder for the given id.
|
||||||
|
func (c *UserAttributeValueClient) UpdateOneID(id int64) *UserAttributeValueUpdateOne {
|
||||||
|
mutation := newUserAttributeValueMutation(c.config, OpUpdateOne, withUserAttributeValueID(id))
|
||||||
|
return &UserAttributeValueUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete returns a delete builder for UserAttributeValue.
|
||||||
|
func (c *UserAttributeValueClient) Delete() *UserAttributeValueDelete {
|
||||||
|
mutation := newUserAttributeValueMutation(c.config, OpDelete)
|
||||||
|
return &UserAttributeValueDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOne returns a builder for deleting the given entity.
|
||||||
|
func (c *UserAttributeValueClient) DeleteOne(_m *UserAttributeValue) *UserAttributeValueDeleteOne {
|
||||||
|
return c.DeleteOneID(_m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||||
|
func (c *UserAttributeValueClient) DeleteOneID(id int64) *UserAttributeValueDeleteOne {
|
||||||
|
builder := c.Delete().Where(userattributevalue.ID(id))
|
||||||
|
builder.mutation.id = &id
|
||||||
|
builder.mutation.op = OpDeleteOne
|
||||||
|
return &UserAttributeValueDeleteOne{builder}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query returns a query builder for UserAttributeValue.
|
||||||
|
func (c *UserAttributeValueClient) Query() *UserAttributeValueQuery {
|
||||||
|
return &UserAttributeValueQuery{
|
||||||
|
config: c.config,
|
||||||
|
ctx: &QueryContext{Type: TypeUserAttributeValue},
|
||||||
|
inters: c.Interceptors(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a UserAttributeValue entity by its id.
|
||||||
|
func (c *UserAttributeValueClient) Get(ctx context.Context, id int64) (*UserAttributeValue, error) {
|
||||||
|
return c.Query().Where(userattributevalue.ID(id)).Only(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetX is like Get, but panics if an error occurs.
|
||||||
|
func (c *UserAttributeValueClient) GetX(ctx context.Context, id int64) *UserAttributeValue {
|
||||||
|
obj, err := c.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryUser queries the user edge of a UserAttributeValue.
|
||||||
|
func (c *UserAttributeValueClient) QueryUser(_m *UserAttributeValue) *UserQuery {
|
||||||
|
query := (&UserClient{config: c.config}).Query()
|
||||||
|
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||||
|
id := _m.ID
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, id),
|
||||||
|
sqlgraph.To(user.Table, user.FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.UserTable, userattributevalue.UserColumn),
|
||||||
|
)
|
||||||
|
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||||
|
return fromV, nil
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryDefinition queries the definition edge of a UserAttributeValue.
|
||||||
|
func (c *UserAttributeValueClient) QueryDefinition(_m *UserAttributeValue) *UserAttributeDefinitionQuery {
|
||||||
|
query := (&UserAttributeDefinitionClient{config: c.config}).Query()
|
||||||
|
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||||
|
id := _m.ID
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, id),
|
||||||
|
sqlgraph.To(userattributedefinition.Table, userattributedefinition.FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.DefinitionTable, userattributevalue.DefinitionColumn),
|
||||||
|
)
|
||||||
|
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||||
|
return fromV, nil
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hooks returns the client hooks.
|
||||||
|
func (c *UserAttributeValueClient) Hooks() []Hook {
|
||||||
|
return c.hooks.UserAttributeValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interceptors returns the client interceptors.
|
||||||
|
func (c *UserAttributeValueClient) Interceptors() []Interceptor {
|
||||||
|
return c.inters.UserAttributeValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserAttributeValueClient) mutate(ctx context.Context, m *UserAttributeValueMutation) (Value, error) {
|
||||||
|
switch m.Op() {
|
||||||
|
case OpCreate:
|
||||||
|
return (&UserAttributeValueCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||||
|
case OpUpdate:
|
||||||
|
return (&UserAttributeValueUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||||
|
case OpUpdateOne:
|
||||||
|
return (&UserAttributeValueUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||||
|
case OpDelete, OpDeleteOne:
|
||||||
|
return (&UserAttributeValueDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("ent: unknown UserAttributeValue mutation op: %q", m.Op())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// UserSubscriptionClient is a client for the UserSubscription schema.
|
// UserSubscriptionClient is a client for the UserSubscription schema.
|
||||||
type UserSubscriptionClient struct {
|
type UserSubscriptionClient struct {
|
||||||
config
|
config
|
||||||
@@ -2278,11 +2628,13 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
|
|||||||
type (
|
type (
|
||||||
hooks struct {
|
hooks struct {
|
||||||
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
|
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
|
||||||
User, UserAllowedGroup, UserSubscription []ent.Hook
|
User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||||
|
UserSubscription []ent.Hook
|
||||||
}
|
}
|
||||||
inters struct {
|
inters struct {
|
||||||
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
|
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
|
||||||
User, UserAllowedGroup, UserSubscription []ent.Interceptor
|
User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||||
|
UserSubscription []ent.Interceptor
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,17 +85,19 @@ var (
|
|||||||
func checkColumn(t, c string) error {
|
func checkColumn(t, c string) error {
|
||||||
initCheck.Do(func() {
|
initCheck.Do(func() {
|
||||||
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
|
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
|
||||||
account.Table: account.ValidColumn,
|
account.Table: account.ValidColumn,
|
||||||
accountgroup.Table: accountgroup.ValidColumn,
|
accountgroup.Table: accountgroup.ValidColumn,
|
||||||
apikey.Table: apikey.ValidColumn,
|
apikey.Table: apikey.ValidColumn,
|
||||||
group.Table: group.ValidColumn,
|
group.Table: group.ValidColumn,
|
||||||
proxy.Table: proxy.ValidColumn,
|
proxy.Table: proxy.ValidColumn,
|
||||||
redeemcode.Table: redeemcode.ValidColumn,
|
redeemcode.Table: redeemcode.ValidColumn,
|
||||||
setting.Table: setting.ValidColumn,
|
setting.Table: setting.ValidColumn,
|
||||||
usagelog.Table: usagelog.ValidColumn,
|
usagelog.Table: usagelog.ValidColumn,
|
||||||
user.Table: user.ValidColumn,
|
user.Table: user.ValidColumn,
|
||||||
userallowedgroup.Table: userallowedgroup.ValidColumn,
|
userallowedgroup.Table: userallowedgroup.ValidColumn,
|
||||||
usersubscription.Table: usersubscription.ValidColumn,
|
userattributedefinition.Table: userattributedefinition.ValidColumn,
|
||||||
|
userattributevalue.Table: userattributevalue.ValidColumn,
|
||||||
|
usersubscription.Table: usersubscription.ValidColumn,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
return columnCheck(t, c)
|
return columnCheck(t, c)
|
||||||
|
|||||||
@@ -129,6 +129,30 @@ func (f UserAllowedGroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
|
|||||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAllowedGroupMutation", m)
|
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAllowedGroupMutation", m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The UserAttributeDefinitionFunc type is an adapter to allow the use of ordinary
|
||||||
|
// function as UserAttributeDefinition mutator.
|
||||||
|
type UserAttributeDefinitionFunc func(context.Context, *ent.UserAttributeDefinitionMutation) (ent.Value, error)
|
||||||
|
|
||||||
|
// Mutate calls f(ctx, m).
|
||||||
|
func (f UserAttributeDefinitionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||||
|
if mv, ok := m.(*ent.UserAttributeDefinitionMutation); ok {
|
||||||
|
return f(ctx, mv)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeDefinitionMutation", m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The UserAttributeValueFunc type is an adapter to allow the use of ordinary
|
||||||
|
// function as UserAttributeValue mutator.
|
||||||
|
type UserAttributeValueFunc func(context.Context, *ent.UserAttributeValueMutation) (ent.Value, error)
|
||||||
|
|
||||||
|
// Mutate calls f(ctx, m).
|
||||||
|
func (f UserAttributeValueFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||||
|
if mv, ok := m.(*ent.UserAttributeValueMutation); ok {
|
||||||
|
return f(ctx, mv)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeValueMutation", m)
|
||||||
|
}
|
||||||
|
|
||||||
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary
|
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary
|
||||||
// function as UserSubscription mutator.
|
// function as UserSubscription mutator.
|
||||||
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error)
|
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error)
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -348,6 +350,60 @@ func (f TraverseUserAllowedGroup) Traverse(ctx context.Context, q ent.Query) err
|
|||||||
return fmt.Errorf("unexpected query type %T. expect *ent.UserAllowedGroupQuery", q)
|
return fmt.Errorf("unexpected query type %T. expect *ent.UserAllowedGroupQuery", q)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The UserAttributeDefinitionFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||||
|
type UserAttributeDefinitionFunc func(context.Context, *ent.UserAttributeDefinitionQuery) (ent.Value, error)
|
||||||
|
|
||||||
|
// Query calls f(ctx, q).
|
||||||
|
func (f UserAttributeDefinitionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||||
|
if q, ok := q.(*ent.UserAttributeDefinitionQuery); ok {
|
||||||
|
return f(ctx, q)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeDefinitionQuery", q)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The TraverseUserAttributeDefinition type is an adapter to allow the use of ordinary function as Traverser.
|
||||||
|
type TraverseUserAttributeDefinition func(context.Context, *ent.UserAttributeDefinitionQuery) error
|
||||||
|
|
||||||
|
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||||
|
func (f TraverseUserAttributeDefinition) Intercept(next ent.Querier) ent.Querier {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse calls f(ctx, q).
|
||||||
|
func (f TraverseUserAttributeDefinition) Traverse(ctx context.Context, q ent.Query) error {
|
||||||
|
if q, ok := q.(*ent.UserAttributeDefinitionQuery); ok {
|
||||||
|
return f(ctx, q)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeDefinitionQuery", q)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The UserAttributeValueFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||||
|
type UserAttributeValueFunc func(context.Context, *ent.UserAttributeValueQuery) (ent.Value, error)
|
||||||
|
|
||||||
|
// Query calls f(ctx, q).
|
||||||
|
func (f UserAttributeValueFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||||
|
if q, ok := q.(*ent.UserAttributeValueQuery); ok {
|
||||||
|
return f(ctx, q)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The TraverseUserAttributeValue type is an adapter to allow the use of ordinary function as Traverser.
|
||||||
|
type TraverseUserAttributeValue func(context.Context, *ent.UserAttributeValueQuery) error
|
||||||
|
|
||||||
|
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||||
|
func (f TraverseUserAttributeValue) Intercept(next ent.Querier) ent.Querier {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse calls f(ctx, q).
|
||||||
|
func (f TraverseUserAttributeValue) Traverse(ctx context.Context, q ent.Query) error {
|
||||||
|
if q, ok := q.(*ent.UserAttributeValueQuery); ok {
|
||||||
|
return f(ctx, q)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
|
||||||
|
}
|
||||||
|
|
||||||
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier.
|
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||||
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error)
|
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error)
|
||||||
|
|
||||||
@@ -398,6 +454,10 @@ func NewQuery(q ent.Query) (Query, error) {
|
|||||||
return &query[*ent.UserQuery, predicate.User, user.OrderOption]{typ: ent.TypeUser, tq: q}, nil
|
return &query[*ent.UserQuery, predicate.User, user.OrderOption]{typ: ent.TypeUser, tq: q}, nil
|
||||||
case *ent.UserAllowedGroupQuery:
|
case *ent.UserAllowedGroupQuery:
|
||||||
return &query[*ent.UserAllowedGroupQuery, predicate.UserAllowedGroup, userallowedgroup.OrderOption]{typ: ent.TypeUserAllowedGroup, tq: q}, nil
|
return &query[*ent.UserAllowedGroupQuery, predicate.UserAllowedGroup, userallowedgroup.OrderOption]{typ: ent.TypeUserAllowedGroup, tq: q}, nil
|
||||||
|
case *ent.UserAttributeDefinitionQuery:
|
||||||
|
return &query[*ent.UserAttributeDefinitionQuery, predicate.UserAttributeDefinition, userattributedefinition.OrderOption]{typ: ent.TypeUserAttributeDefinition, tq: q}, nil
|
||||||
|
case *ent.UserAttributeValueQuery:
|
||||||
|
return &query[*ent.UserAttributeValueQuery, predicate.UserAttributeValue, userattributevalue.OrderOption]{typ: ent.TypeUserAttributeValue, tq: q}, nil
|
||||||
case *ent.UserSubscriptionQuery:
|
case *ent.UserSubscriptionQuery:
|
||||||
return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil
|
return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -477,7 +477,6 @@ var (
|
|||||||
{Name: "concurrency", Type: field.TypeInt, Default: 5},
|
{Name: "concurrency", Type: field.TypeInt, Default: 5},
|
||||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||||
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
|
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
|
||||||
{Name: "wechat", Type: field.TypeString, Size: 100, Default: ""},
|
|
||||||
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
|
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
|
||||||
}
|
}
|
||||||
// UsersTable holds the schema information for the "users" table.
|
// UsersTable holds the schema information for the "users" table.
|
||||||
@@ -531,6 +530,92 @@ var (
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
// UserAttributeDefinitionsColumns holds the columns for the "user_attribute_definitions" table.
|
||||||
|
UserAttributeDefinitionsColumns = []*schema.Column{
|
||||||
|
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||||
|
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
|
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
|
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
|
{Name: "key", Type: field.TypeString, Size: 100},
|
||||||
|
{Name: "name", Type: field.TypeString, Size: 255},
|
||||||
|
{Name: "description", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
|
||||||
|
{Name: "type", Type: field.TypeString, Size: 20},
|
||||||
|
{Name: "options", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||||
|
{Name: "required", Type: field.TypeBool, Default: false},
|
||||||
|
{Name: "validation", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||||
|
{Name: "placeholder", Type: field.TypeString, Size: 255, Default: ""},
|
||||||
|
{Name: "display_order", Type: field.TypeInt, Default: 0},
|
||||||
|
{Name: "enabled", Type: field.TypeBool, Default: true},
|
||||||
|
}
|
||||||
|
// UserAttributeDefinitionsTable holds the schema information for the "user_attribute_definitions" table.
|
||||||
|
UserAttributeDefinitionsTable = &schema.Table{
|
||||||
|
Name: "user_attribute_definitions",
|
||||||
|
Columns: UserAttributeDefinitionsColumns,
|
||||||
|
PrimaryKey: []*schema.Column{UserAttributeDefinitionsColumns[0]},
|
||||||
|
Indexes: []*schema.Index{
|
||||||
|
{
|
||||||
|
Name: "userattributedefinition_key",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{UserAttributeDefinitionsColumns[4]},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "userattributedefinition_enabled",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{UserAttributeDefinitionsColumns[13]},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "userattributedefinition_display_order",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{UserAttributeDefinitionsColumns[12]},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "userattributedefinition_deleted_at",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{UserAttributeDefinitionsColumns[3]},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// UserAttributeValuesColumns holds the columns for the "user_attribute_values" table.
|
||||||
|
UserAttributeValuesColumns = []*schema.Column{
|
||||||
|
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||||
|
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
|
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
|
{Name: "value", Type: field.TypeString, Size: 2147483647, Default: ""},
|
||||||
|
{Name: "user_id", Type: field.TypeInt64},
|
||||||
|
{Name: "attribute_id", Type: field.TypeInt64},
|
||||||
|
}
|
||||||
|
// UserAttributeValuesTable holds the schema information for the "user_attribute_values" table.
|
||||||
|
UserAttributeValuesTable = &schema.Table{
|
||||||
|
Name: "user_attribute_values",
|
||||||
|
Columns: UserAttributeValuesColumns,
|
||||||
|
PrimaryKey: []*schema.Column{UserAttributeValuesColumns[0]},
|
||||||
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
|
{
|
||||||
|
Symbol: "user_attribute_values_users_attribute_values",
|
||||||
|
Columns: []*schema.Column{UserAttributeValuesColumns[4]},
|
||||||
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
|
OnDelete: schema.NoAction,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Symbol: "user_attribute_values_user_attribute_definitions_values",
|
||||||
|
Columns: []*schema.Column{UserAttributeValuesColumns[5]},
|
||||||
|
RefColumns: []*schema.Column{UserAttributeDefinitionsColumns[0]},
|
||||||
|
OnDelete: schema.NoAction,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Indexes: []*schema.Index{
|
||||||
|
{
|
||||||
|
Name: "userattributevalue_user_id_attribute_id",
|
||||||
|
Unique: true,
|
||||||
|
Columns: []*schema.Column{UserAttributeValuesColumns[4], UserAttributeValuesColumns[5]},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "userattributevalue_attribute_id",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{UserAttributeValuesColumns[5]},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
// UserSubscriptionsColumns holds the columns for the "user_subscriptions" table.
|
// UserSubscriptionsColumns holds the columns for the "user_subscriptions" table.
|
||||||
UserSubscriptionsColumns = []*schema.Column{
|
UserSubscriptionsColumns = []*schema.Column{
|
||||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||||
@@ -627,6 +712,8 @@ var (
|
|||||||
UsageLogsTable,
|
UsageLogsTable,
|
||||||
UsersTable,
|
UsersTable,
|
||||||
UserAllowedGroupsTable,
|
UserAllowedGroupsTable,
|
||||||
|
UserAttributeDefinitionsTable,
|
||||||
|
UserAttributeValuesTable,
|
||||||
UserSubscriptionsTable,
|
UserSubscriptionsTable,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -676,6 +763,14 @@ func init() {
|
|||||||
UserAllowedGroupsTable.Annotation = &entsql.Annotation{
|
UserAllowedGroupsTable.Annotation = &entsql.Annotation{
|
||||||
Table: "user_allowed_groups",
|
Table: "user_allowed_groups",
|
||||||
}
|
}
|
||||||
|
UserAttributeDefinitionsTable.Annotation = &entsql.Annotation{
|
||||||
|
Table: "user_attribute_definitions",
|
||||||
|
}
|
||||||
|
UserAttributeValuesTable.ForeignKeys[0].RefTable = UsersTable
|
||||||
|
UserAttributeValuesTable.ForeignKeys[1].RefTable = UserAttributeDefinitionsTable
|
||||||
|
UserAttributeValuesTable.Annotation = &entsql.Annotation{
|
||||||
|
Table: "user_attribute_values",
|
||||||
|
}
|
||||||
UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable
|
UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable
|
||||||
UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable
|
UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable
|
||||||
UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable
|
UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -36,5 +36,11 @@ type User func(*sql.Selector)
|
|||||||
// UserAllowedGroup is the predicate function for userallowedgroup builders.
|
// UserAllowedGroup is the predicate function for userallowedgroup builders.
|
||||||
type UserAllowedGroup func(*sql.Selector)
|
type UserAllowedGroup func(*sql.Selector)
|
||||||
|
|
||||||
|
// UserAttributeDefinition is the predicate function for userattributedefinition builders.
|
||||||
|
type UserAttributeDefinition func(*sql.Selector)
|
||||||
|
|
||||||
|
// UserAttributeValue is the predicate function for userattributevalue builders.
|
||||||
|
type UserAttributeValue func(*sql.Selector)
|
||||||
|
|
||||||
// UserSubscription is the predicate function for usersubscription builders.
|
// UserSubscription is the predicate function for usersubscription builders.
|
||||||
type UserSubscription func(*sql.Selector)
|
type UserSubscription func(*sql.Selector)
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -604,14 +606,8 @@ func init() {
|
|||||||
user.DefaultUsername = userDescUsername.Default.(string)
|
user.DefaultUsername = userDescUsername.Default.(string)
|
||||||
// user.UsernameValidator is a validator for the "username" field. It is called by the builders before save.
|
// user.UsernameValidator is a validator for the "username" field. It is called by the builders before save.
|
||||||
user.UsernameValidator = userDescUsername.Validators[0].(func(string) error)
|
user.UsernameValidator = userDescUsername.Validators[0].(func(string) error)
|
||||||
// userDescWechat is the schema descriptor for wechat field.
|
|
||||||
userDescWechat := userFields[7].Descriptor()
|
|
||||||
// user.DefaultWechat holds the default value on creation for the wechat field.
|
|
||||||
user.DefaultWechat = userDescWechat.Default.(string)
|
|
||||||
// user.WechatValidator is a validator for the "wechat" field. It is called by the builders before save.
|
|
||||||
user.WechatValidator = userDescWechat.Validators[0].(func(string) error)
|
|
||||||
// userDescNotes is the schema descriptor for notes field.
|
// userDescNotes is the schema descriptor for notes field.
|
||||||
userDescNotes := userFields[8].Descriptor()
|
userDescNotes := userFields[7].Descriptor()
|
||||||
// user.DefaultNotes holds the default value on creation for the notes field.
|
// user.DefaultNotes holds the default value on creation for the notes field.
|
||||||
user.DefaultNotes = userDescNotes.Default.(string)
|
user.DefaultNotes = userDescNotes.Default.(string)
|
||||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||||
@@ -620,6 +616,128 @@ func init() {
|
|||||||
userallowedgroupDescCreatedAt := userallowedgroupFields[2].Descriptor()
|
userallowedgroupDescCreatedAt := userallowedgroupFields[2].Descriptor()
|
||||||
// userallowedgroup.DefaultCreatedAt holds the default value on creation for the created_at field.
|
// userallowedgroup.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
userallowedgroup.DefaultCreatedAt = userallowedgroupDescCreatedAt.Default.(func() time.Time)
|
userallowedgroup.DefaultCreatedAt = userallowedgroupDescCreatedAt.Default.(func() time.Time)
|
||||||
|
userattributedefinitionMixin := schema.UserAttributeDefinition{}.Mixin()
|
||||||
|
userattributedefinitionMixinHooks1 := userattributedefinitionMixin[1].Hooks()
|
||||||
|
userattributedefinition.Hooks[0] = userattributedefinitionMixinHooks1[0]
|
||||||
|
userattributedefinitionMixinInters1 := userattributedefinitionMixin[1].Interceptors()
|
||||||
|
userattributedefinition.Interceptors[0] = userattributedefinitionMixinInters1[0]
|
||||||
|
userattributedefinitionMixinFields0 := userattributedefinitionMixin[0].Fields()
|
||||||
|
_ = userattributedefinitionMixinFields0
|
||||||
|
userattributedefinitionFields := schema.UserAttributeDefinition{}.Fields()
|
||||||
|
_ = userattributedefinitionFields
|
||||||
|
// userattributedefinitionDescCreatedAt is the schema descriptor for created_at field.
|
||||||
|
userattributedefinitionDescCreatedAt := userattributedefinitionMixinFields0[0].Descriptor()
|
||||||
|
// userattributedefinition.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
|
userattributedefinition.DefaultCreatedAt = userattributedefinitionDescCreatedAt.Default.(func() time.Time)
|
||||||
|
// userattributedefinitionDescUpdatedAt is the schema descriptor for updated_at field.
|
||||||
|
userattributedefinitionDescUpdatedAt := userattributedefinitionMixinFields0[1].Descriptor()
|
||||||
|
// userattributedefinition.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||||
|
userattributedefinition.DefaultUpdatedAt = userattributedefinitionDescUpdatedAt.Default.(func() time.Time)
|
||||||
|
// userattributedefinition.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||||
|
userattributedefinition.UpdateDefaultUpdatedAt = userattributedefinitionDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||||
|
// userattributedefinitionDescKey is the schema descriptor for key field.
|
||||||
|
userattributedefinitionDescKey := userattributedefinitionFields[0].Descriptor()
|
||||||
|
// userattributedefinition.KeyValidator is a validator for the "key" field. It is called by the builders before save.
|
||||||
|
userattributedefinition.KeyValidator = func() func(string) error {
|
||||||
|
validators := userattributedefinitionDescKey.Validators
|
||||||
|
fns := [...]func(string) error{
|
||||||
|
validators[0].(func(string) error),
|
||||||
|
validators[1].(func(string) error),
|
||||||
|
}
|
||||||
|
return func(key string) error {
|
||||||
|
for _, fn := range fns {
|
||||||
|
if err := fn(key); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
// userattributedefinitionDescName is the schema descriptor for name field.
|
||||||
|
userattributedefinitionDescName := userattributedefinitionFields[1].Descriptor()
|
||||||
|
// userattributedefinition.NameValidator is a validator for the "name" field. It is called by the builders before save.
|
||||||
|
userattributedefinition.NameValidator = func() func(string) error {
|
||||||
|
validators := userattributedefinitionDescName.Validators
|
||||||
|
fns := [...]func(string) error{
|
||||||
|
validators[0].(func(string) error),
|
||||||
|
validators[1].(func(string) error),
|
||||||
|
}
|
||||||
|
return func(name string) error {
|
||||||
|
for _, fn := range fns {
|
||||||
|
if err := fn(name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
// userattributedefinitionDescDescription is the schema descriptor for description field.
|
||||||
|
userattributedefinitionDescDescription := userattributedefinitionFields[2].Descriptor()
|
||||||
|
// userattributedefinition.DefaultDescription holds the default value on creation for the description field.
|
||||||
|
userattributedefinition.DefaultDescription = userattributedefinitionDescDescription.Default.(string)
|
||||||
|
// userattributedefinitionDescType is the schema descriptor for type field.
|
||||||
|
userattributedefinitionDescType := userattributedefinitionFields[3].Descriptor()
|
||||||
|
// userattributedefinition.TypeValidator is a validator for the "type" field. It is called by the builders before save.
|
||||||
|
userattributedefinition.TypeValidator = func() func(string) error {
|
||||||
|
validators := userattributedefinitionDescType.Validators
|
||||||
|
fns := [...]func(string) error{
|
||||||
|
validators[0].(func(string) error),
|
||||||
|
validators[1].(func(string) error),
|
||||||
|
}
|
||||||
|
return func(_type string) error {
|
||||||
|
for _, fn := range fns {
|
||||||
|
if err := fn(_type); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
// userattributedefinitionDescOptions is the schema descriptor for options field.
|
||||||
|
userattributedefinitionDescOptions := userattributedefinitionFields[4].Descriptor()
|
||||||
|
// userattributedefinition.DefaultOptions holds the default value on creation for the options field.
|
||||||
|
userattributedefinition.DefaultOptions = userattributedefinitionDescOptions.Default.([]map[string]interface{})
|
||||||
|
// userattributedefinitionDescRequired is the schema descriptor for required field.
|
||||||
|
userattributedefinitionDescRequired := userattributedefinitionFields[5].Descriptor()
|
||||||
|
// userattributedefinition.DefaultRequired holds the default value on creation for the required field.
|
||||||
|
userattributedefinition.DefaultRequired = userattributedefinitionDescRequired.Default.(bool)
|
||||||
|
// userattributedefinitionDescValidation is the schema descriptor for validation field.
|
||||||
|
userattributedefinitionDescValidation := userattributedefinitionFields[6].Descriptor()
|
||||||
|
// userattributedefinition.DefaultValidation holds the default value on creation for the validation field.
|
||||||
|
userattributedefinition.DefaultValidation = userattributedefinitionDescValidation.Default.(map[string]interface{})
|
||||||
|
// userattributedefinitionDescPlaceholder is the schema descriptor for placeholder field.
|
||||||
|
userattributedefinitionDescPlaceholder := userattributedefinitionFields[7].Descriptor()
|
||||||
|
// userattributedefinition.DefaultPlaceholder holds the default value on creation for the placeholder field.
|
||||||
|
userattributedefinition.DefaultPlaceholder = userattributedefinitionDescPlaceholder.Default.(string)
|
||||||
|
// userattributedefinition.PlaceholderValidator is a validator for the "placeholder" field. It is called by the builders before save.
|
||||||
|
userattributedefinition.PlaceholderValidator = userattributedefinitionDescPlaceholder.Validators[0].(func(string) error)
|
||||||
|
// userattributedefinitionDescDisplayOrder is the schema descriptor for display_order field.
|
||||||
|
userattributedefinitionDescDisplayOrder := userattributedefinitionFields[8].Descriptor()
|
||||||
|
// userattributedefinition.DefaultDisplayOrder holds the default value on creation for the display_order field.
|
||||||
|
userattributedefinition.DefaultDisplayOrder = userattributedefinitionDescDisplayOrder.Default.(int)
|
||||||
|
// userattributedefinitionDescEnabled is the schema descriptor for enabled field.
|
||||||
|
userattributedefinitionDescEnabled := userattributedefinitionFields[9].Descriptor()
|
||||||
|
// userattributedefinition.DefaultEnabled holds the default value on creation for the enabled field.
|
||||||
|
userattributedefinition.DefaultEnabled = userattributedefinitionDescEnabled.Default.(bool)
|
||||||
|
userattributevalueMixin := schema.UserAttributeValue{}.Mixin()
|
||||||
|
userattributevalueMixinFields0 := userattributevalueMixin[0].Fields()
|
||||||
|
_ = userattributevalueMixinFields0
|
||||||
|
userattributevalueFields := schema.UserAttributeValue{}.Fields()
|
||||||
|
_ = userattributevalueFields
|
||||||
|
// userattributevalueDescCreatedAt is the schema descriptor for created_at field.
|
||||||
|
userattributevalueDescCreatedAt := userattributevalueMixinFields0[0].Descriptor()
|
||||||
|
// userattributevalue.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
|
userattributevalue.DefaultCreatedAt = userattributevalueDescCreatedAt.Default.(func() time.Time)
|
||||||
|
// userattributevalueDescUpdatedAt is the schema descriptor for updated_at field.
|
||||||
|
userattributevalueDescUpdatedAt := userattributevalueMixinFields0[1].Descriptor()
|
||||||
|
// userattributevalue.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||||
|
userattributevalue.DefaultUpdatedAt = userattributevalueDescUpdatedAt.Default.(func() time.Time)
|
||||||
|
// userattributevalue.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||||
|
userattributevalue.UpdateDefaultUpdatedAt = userattributevalueDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||||
|
// userattributevalueDescValue is the schema descriptor for value field.
|
||||||
|
userattributevalueDescValue := userattributevalueFields[2].Descriptor()
|
||||||
|
// userattributevalue.DefaultValue holds the default value on creation for the value field.
|
||||||
|
userattributevalue.DefaultValue = userattributevalueDescValue.Default.(string)
|
||||||
usersubscriptionMixin := schema.UserSubscription{}.Mixin()
|
usersubscriptionMixin := schema.UserSubscription{}.Mixin()
|
||||||
usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks()
|
usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks()
|
||||||
usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0]
|
usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0]
|
||||||
|
|||||||
@@ -57,9 +57,7 @@ func (User) Fields() []ent.Field {
|
|||||||
field.String("username").
|
field.String("username").
|
||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
Default(""),
|
Default(""),
|
||||||
field.String("wechat").
|
// wechat field migrated to user_attribute_values (see migration 019)
|
||||||
MaxLen(100).
|
|
||||||
Default(""),
|
|
||||||
field.String("notes").
|
field.String("notes").
|
||||||
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||||
Default(""),
|
Default(""),
|
||||||
@@ -75,6 +73,7 @@ func (User) Edges() []ent.Edge {
|
|||||||
edge.To("allowed_groups", Group.Type).
|
edge.To("allowed_groups", Group.Type).
|
||||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||||
edge.To("usage_logs", UsageLog.Type),
|
edge.To("usage_logs", UsageLog.Type),
|
||||||
|
edge.To("attribute_values", UserAttributeValue.Type),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
109
backend/ent/schema/user_attribute_definition.go
Normal file
109
backend/ent/schema/user_attribute_definition.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect"
|
||||||
|
"entgo.io/ent/dialect/entsql"
|
||||||
|
"entgo.io/ent/schema"
|
||||||
|
"entgo.io/ent/schema/edge"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"entgo.io/ent/schema/index"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeDefinition holds the schema definition for custom user attributes.
|
||||||
|
//
|
||||||
|
// This entity defines the metadata for user attributes, such as:
|
||||||
|
// - Attribute key (unique identifier like "company_name")
|
||||||
|
// - Display name shown in forms
|
||||||
|
// - Field type (text, number, select, etc.)
|
||||||
|
// - Validation rules
|
||||||
|
// - Whether the field is required or enabled
|
||||||
|
type UserAttributeDefinition struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeDefinition) Annotations() []schema.Annotation {
|
||||||
|
return []schema.Annotation{
|
||||||
|
entsql.Annotation{Table: "user_attribute_definitions"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeDefinition) Mixin() []ent.Mixin {
|
||||||
|
return []ent.Mixin{
|
||||||
|
mixins.TimeMixin{},
|
||||||
|
mixins.SoftDeleteMixin{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeDefinition) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
// key: Unique identifier for the attribute (e.g., "company_name")
|
||||||
|
// Used for programmatic reference
|
||||||
|
field.String("key").
|
||||||
|
MaxLen(100).
|
||||||
|
NotEmpty(),
|
||||||
|
|
||||||
|
// name: Display name shown in forms (e.g., "Company Name")
|
||||||
|
field.String("name").
|
||||||
|
MaxLen(255).
|
||||||
|
NotEmpty(),
|
||||||
|
|
||||||
|
// description: Optional description/help text for the attribute
|
||||||
|
field.String("description").
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||||
|
Default(""),
|
||||||
|
|
||||||
|
// type: Attribute type - text, textarea, number, email, url, date, select, multi_select
|
||||||
|
field.String("type").
|
||||||
|
MaxLen(20).
|
||||||
|
NotEmpty(),
|
||||||
|
|
||||||
|
// options: Select options for select/multi_select types (stored as JSONB)
|
||||||
|
// Format: [{"value": "xxx", "label": "XXX"}, ...]
|
||||||
|
field.JSON("options", []map[string]any{}).
|
||||||
|
Default([]map[string]any{}).
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||||
|
|
||||||
|
// required: Whether this attribute is required when editing a user
|
||||||
|
field.Bool("required").
|
||||||
|
Default(false),
|
||||||
|
|
||||||
|
// validation: Validation rules for the attribute value (stored as JSONB)
|
||||||
|
// Format: {"min_length": 1, "max_length": 100, "min": 0, "max": 100, "pattern": "^[a-z]+$", "message": "..."}
|
||||||
|
field.JSON("validation", map[string]any{}).
|
||||||
|
Default(map[string]any{}).
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||||
|
|
||||||
|
// placeholder: Placeholder text shown in input fields
|
||||||
|
field.String("placeholder").
|
||||||
|
MaxLen(255).
|
||||||
|
Default(""),
|
||||||
|
|
||||||
|
// display_order: Order in which attributes are displayed (lower = first)
|
||||||
|
field.Int("display_order").
|
||||||
|
Default(0),
|
||||||
|
|
||||||
|
// enabled: Whether this attribute is active and shown in forms
|
||||||
|
field.Bool("enabled").
|
||||||
|
Default(true),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeDefinition) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{
|
||||||
|
// values: All user values for this attribute definition
|
||||||
|
edge.To("values", UserAttributeValue.Type),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeDefinition) Indexes() []ent.Index {
|
||||||
|
return []ent.Index{
|
||||||
|
// Partial unique index on key (WHERE deleted_at IS NULL) via migration
|
||||||
|
index.Fields("key"),
|
||||||
|
index.Fields("enabled"),
|
||||||
|
index.Fields("display_order"),
|
||||||
|
index.Fields("deleted_at"),
|
||||||
|
}
|
||||||
|
}
|
||||||
74
backend/ent/schema/user_attribute_value.go
Normal file
74
backend/ent/schema/user_attribute_value.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect/entsql"
|
||||||
|
"entgo.io/ent/schema"
|
||||||
|
"entgo.io/ent/schema/edge"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"entgo.io/ent/schema/index"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeValue holds a user's value for a specific attribute.
|
||||||
|
//
|
||||||
|
// This entity stores the actual values that users have for each attribute definition.
|
||||||
|
// Values are stored as strings and converted to the appropriate type by the application.
|
||||||
|
type UserAttributeValue struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeValue) Annotations() []schema.Annotation {
|
||||||
|
return []schema.Annotation{
|
||||||
|
entsql.Annotation{Table: "user_attribute_values"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeValue) Mixin() []ent.Mixin {
|
||||||
|
return []ent.Mixin{
|
||||||
|
// Only use TimeMixin, no soft delete - values are hard deleted
|
||||||
|
mixins.TimeMixin{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeValue) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
// user_id: References the user this value belongs to
|
||||||
|
field.Int64("user_id"),
|
||||||
|
|
||||||
|
// attribute_id: References the attribute definition
|
||||||
|
field.Int64("attribute_id"),
|
||||||
|
|
||||||
|
// value: The actual value stored as a string
|
||||||
|
// For multi_select, this is a JSON array string
|
||||||
|
field.Text("value").
|
||||||
|
Default(""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeValue) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{
|
||||||
|
// user: The user who owns this attribute value
|
||||||
|
edge.From("user", User.Type).
|
||||||
|
Ref("attribute_values").
|
||||||
|
Field("user_id").
|
||||||
|
Required().
|
||||||
|
Unique(),
|
||||||
|
|
||||||
|
// definition: The attribute definition this value is for
|
||||||
|
edge.From("definition", UserAttributeDefinition.Type).
|
||||||
|
Ref("values").
|
||||||
|
Field("attribute_id").
|
||||||
|
Required().
|
||||||
|
Unique(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttributeValue) Indexes() []ent.Index {
|
||||||
|
return []ent.Index{
|
||||||
|
// Unique index on (user_id, attribute_id)
|
||||||
|
index.Fields("user_id", "attribute_id").Unique(),
|
||||||
|
index.Fields("attribute_id"),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -34,6 +34,10 @@ type Tx struct {
|
|||||||
User *UserClient
|
User *UserClient
|
||||||
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
|
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
|
||||||
UserAllowedGroup *UserAllowedGroupClient
|
UserAllowedGroup *UserAllowedGroupClient
|
||||||
|
// UserAttributeDefinition is the client for interacting with the UserAttributeDefinition builders.
|
||||||
|
UserAttributeDefinition *UserAttributeDefinitionClient
|
||||||
|
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
|
||||||
|
UserAttributeValue *UserAttributeValueClient
|
||||||
// UserSubscription is the client for interacting with the UserSubscription builders.
|
// UserSubscription is the client for interacting with the UserSubscription builders.
|
||||||
UserSubscription *UserSubscriptionClient
|
UserSubscription *UserSubscriptionClient
|
||||||
|
|
||||||
@@ -177,6 +181,8 @@ func (tx *Tx) init() {
|
|||||||
tx.UsageLog = NewUsageLogClient(tx.config)
|
tx.UsageLog = NewUsageLogClient(tx.config)
|
||||||
tx.User = NewUserClient(tx.config)
|
tx.User = NewUserClient(tx.config)
|
||||||
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
|
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
|
||||||
|
tx.UserAttributeDefinition = NewUserAttributeDefinitionClient(tx.config)
|
||||||
|
tx.UserAttributeValue = NewUserAttributeValueClient(tx.config)
|
||||||
tx.UserSubscription = NewUserSubscriptionClient(tx.config)
|
tx.UserSubscription = NewUserSubscriptionClient(tx.config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,8 +37,6 @@ type User struct {
|
|||||||
Status string `json:"status,omitempty"`
|
Status string `json:"status,omitempty"`
|
||||||
// Username holds the value of the "username" field.
|
// Username holds the value of the "username" field.
|
||||||
Username string `json:"username,omitempty"`
|
Username string `json:"username,omitempty"`
|
||||||
// Wechat holds the value of the "wechat" field.
|
|
||||||
Wechat string `json:"wechat,omitempty"`
|
|
||||||
// Notes holds the value of the "notes" field.
|
// Notes holds the value of the "notes" field.
|
||||||
Notes string `json:"notes,omitempty"`
|
Notes string `json:"notes,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
@@ -61,11 +59,13 @@ type UserEdges struct {
|
|||||||
AllowedGroups []*Group `json:"allowed_groups,omitempty"`
|
AllowedGroups []*Group `json:"allowed_groups,omitempty"`
|
||||||
// UsageLogs holds the value of the usage_logs edge.
|
// UsageLogs holds the value of the usage_logs edge.
|
||||||
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
|
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
|
||||||
|
// AttributeValues holds the value of the attribute_values edge.
|
||||||
|
AttributeValues []*UserAttributeValue `json:"attribute_values,omitempty"`
|
||||||
// UserAllowedGroups holds the value of the user_allowed_groups edge.
|
// UserAllowedGroups holds the value of the user_allowed_groups edge.
|
||||||
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
|
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
|
||||||
// loadedTypes holds the information for reporting if a
|
// loadedTypes holds the information for reporting if a
|
||||||
// type was loaded (or requested) in eager-loading or not.
|
// type was loaded (or requested) in eager-loading or not.
|
||||||
loadedTypes [7]bool
|
loadedTypes [8]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeysOrErr returns the APIKeys value or an error if the edge
|
// APIKeysOrErr returns the APIKeys value or an error if the edge
|
||||||
@@ -122,10 +122,19 @@ func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
|
|||||||
return nil, &NotLoadedError{edge: "usage_logs"}
|
return nil, &NotLoadedError{edge: "usage_logs"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AttributeValuesOrErr returns the AttributeValues value or an error if the edge
|
||||||
|
// was not loaded in eager-loading.
|
||||||
|
func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
|
||||||
|
if e.loadedTypes[6] {
|
||||||
|
return e.AttributeValues, nil
|
||||||
|
}
|
||||||
|
return nil, &NotLoadedError{edge: "attribute_values"}
|
||||||
|
}
|
||||||
|
|
||||||
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
|
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
|
||||||
// was not loaded in eager-loading.
|
// was not loaded in eager-loading.
|
||||||
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
|
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
|
||||||
if e.loadedTypes[6] {
|
if e.loadedTypes[7] {
|
||||||
return e.UserAllowedGroups, nil
|
return e.UserAllowedGroups, nil
|
||||||
}
|
}
|
||||||
return nil, &NotLoadedError{edge: "user_allowed_groups"}
|
return nil, &NotLoadedError{edge: "user_allowed_groups"}
|
||||||
@@ -140,7 +149,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case user.FieldID, user.FieldConcurrency:
|
case user.FieldID, user.FieldConcurrency:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldWechat, user.FieldNotes:
|
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
|
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
@@ -226,12 +235,6 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.Username = value.String
|
_m.Username = value.String
|
||||||
}
|
}
|
||||||
case user.FieldWechat:
|
|
||||||
if value, ok := values[i].(*sql.NullString); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field wechat", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.Wechat = value.String
|
|
||||||
}
|
|
||||||
case user.FieldNotes:
|
case user.FieldNotes:
|
||||||
if value, ok := values[i].(*sql.NullString); !ok {
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field notes", values[i])
|
return fmt.Errorf("unexpected type %T for field notes", values[i])
|
||||||
@@ -281,6 +284,11 @@ func (_m *User) QueryUsageLogs() *UsageLogQuery {
|
|||||||
return NewUserClient(_m.config).QueryUsageLogs(_m)
|
return NewUserClient(_m.config).QueryUsageLogs(_m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryAttributeValues queries the "attribute_values" edge of the User entity.
|
||||||
|
func (_m *User) QueryAttributeValues() *UserAttributeValueQuery {
|
||||||
|
return NewUserClient(_m.config).QueryAttributeValues(_m)
|
||||||
|
}
|
||||||
|
|
||||||
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
|
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
|
||||||
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
||||||
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
|
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
|
||||||
@@ -341,9 +349,6 @@ func (_m *User) String() string {
|
|||||||
builder.WriteString("username=")
|
builder.WriteString("username=")
|
||||||
builder.WriteString(_m.Username)
|
builder.WriteString(_m.Username)
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("wechat=")
|
|
||||||
builder.WriteString(_m.Wechat)
|
|
||||||
builder.WriteString(", ")
|
|
||||||
builder.WriteString("notes=")
|
builder.WriteString("notes=")
|
||||||
builder.WriteString(_m.Notes)
|
builder.WriteString(_m.Notes)
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
|
|||||||
@@ -35,8 +35,6 @@ const (
|
|||||||
FieldStatus = "status"
|
FieldStatus = "status"
|
||||||
// FieldUsername holds the string denoting the username field in the database.
|
// FieldUsername holds the string denoting the username field in the database.
|
||||||
FieldUsername = "username"
|
FieldUsername = "username"
|
||||||
// FieldWechat holds the string denoting the wechat field in the database.
|
|
||||||
FieldWechat = "wechat"
|
|
||||||
// FieldNotes holds the string denoting the notes field in the database.
|
// FieldNotes holds the string denoting the notes field in the database.
|
||||||
FieldNotes = "notes"
|
FieldNotes = "notes"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
@@ -51,6 +49,8 @@ const (
|
|||||||
EdgeAllowedGroups = "allowed_groups"
|
EdgeAllowedGroups = "allowed_groups"
|
||||||
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
|
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
|
||||||
EdgeUsageLogs = "usage_logs"
|
EdgeUsageLogs = "usage_logs"
|
||||||
|
// EdgeAttributeValues holds the string denoting the attribute_values edge name in mutations.
|
||||||
|
EdgeAttributeValues = "attribute_values"
|
||||||
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
|
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
|
||||||
EdgeUserAllowedGroups = "user_allowed_groups"
|
EdgeUserAllowedGroups = "user_allowed_groups"
|
||||||
// Table holds the table name of the user in the database.
|
// Table holds the table name of the user in the database.
|
||||||
@@ -95,6 +95,13 @@ const (
|
|||||||
UsageLogsInverseTable = "usage_logs"
|
UsageLogsInverseTable = "usage_logs"
|
||||||
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
|
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
|
||||||
UsageLogsColumn = "user_id"
|
UsageLogsColumn = "user_id"
|
||||||
|
// AttributeValuesTable is the table that holds the attribute_values relation/edge.
|
||||||
|
AttributeValuesTable = "user_attribute_values"
|
||||||
|
// AttributeValuesInverseTable is the table name for the UserAttributeValue entity.
|
||||||
|
// It exists in this package in order to avoid circular dependency with the "userattributevalue" package.
|
||||||
|
AttributeValuesInverseTable = "user_attribute_values"
|
||||||
|
// AttributeValuesColumn is the table column denoting the attribute_values relation/edge.
|
||||||
|
AttributeValuesColumn = "user_id"
|
||||||
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
|
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
|
||||||
UserAllowedGroupsTable = "user_allowed_groups"
|
UserAllowedGroupsTable = "user_allowed_groups"
|
||||||
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
|
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
|
||||||
@@ -117,7 +124,6 @@ var Columns = []string{
|
|||||||
FieldConcurrency,
|
FieldConcurrency,
|
||||||
FieldStatus,
|
FieldStatus,
|
||||||
FieldUsername,
|
FieldUsername,
|
||||||
FieldWechat,
|
|
||||||
FieldNotes,
|
FieldNotes,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,10 +177,6 @@ var (
|
|||||||
DefaultUsername string
|
DefaultUsername string
|
||||||
// UsernameValidator is a validator for the "username" field. It is called by the builders before save.
|
// UsernameValidator is a validator for the "username" field. It is called by the builders before save.
|
||||||
UsernameValidator func(string) error
|
UsernameValidator func(string) error
|
||||||
// DefaultWechat holds the default value on creation for the "wechat" field.
|
|
||||||
DefaultWechat string
|
|
||||||
// WechatValidator is a validator for the "wechat" field. It is called by the builders before save.
|
|
||||||
WechatValidator func(string) error
|
|
||||||
// DefaultNotes holds the default value on creation for the "notes" field.
|
// DefaultNotes holds the default value on creation for the "notes" field.
|
||||||
DefaultNotes string
|
DefaultNotes string
|
||||||
)
|
)
|
||||||
@@ -237,11 +239,6 @@ func ByUsername(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldUsername, opts...).ToFunc()
|
return sql.OrderByField(FieldUsername, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ByWechat orders the results by the wechat field.
|
|
||||||
func ByWechat(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldWechat, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByNotes orders the results by the notes field.
|
// ByNotes orders the results by the notes field.
|
||||||
func ByNotes(opts ...sql.OrderTermOption) OrderOption {
|
func ByNotes(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldNotes, opts...).ToFunc()
|
return sql.OrderByField(FieldNotes, opts...).ToFunc()
|
||||||
@@ -331,6 +328,20 @@ func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByAttributeValuesCount orders the results by attribute_values count.
|
||||||
|
func ByAttributeValuesCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return func(s *sql.Selector) {
|
||||||
|
sqlgraph.OrderByNeighborsCount(s, newAttributeValuesStep(), opts...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByAttributeValues orders the results by attribute_values terms.
|
||||||
|
func ByAttributeValues(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
||||||
|
return func(s *sql.Selector) {
|
||||||
|
sqlgraph.OrderByNeighborTerms(s, newAttributeValuesStep(), append([]sql.OrderTerm{term}, terms...)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
|
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
|
||||||
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
@@ -386,6 +397,13 @@ func newUsageLogsStep() *sqlgraph.Step {
|
|||||||
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
|
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
func newAttributeValuesStep() *sqlgraph.Step {
|
||||||
|
return sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(Table, FieldID),
|
||||||
|
sqlgraph.To(AttributeValuesInverseTable, FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.O2M, false, AttributeValuesTable, AttributeValuesColumn),
|
||||||
|
)
|
||||||
|
}
|
||||||
func newUserAllowedGroupsStep() *sqlgraph.Step {
|
func newUserAllowedGroupsStep() *sqlgraph.Step {
|
||||||
return sqlgraph.NewStep(
|
return sqlgraph.NewStep(
|
||||||
sqlgraph.From(Table, FieldID),
|
sqlgraph.From(Table, FieldID),
|
||||||
|
|||||||
@@ -105,11 +105,6 @@ func Username(v string) predicate.User {
|
|||||||
return predicate.User(sql.FieldEQ(FieldUsername, v))
|
return predicate.User(sql.FieldEQ(FieldUsername, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wechat applies equality check predicate on the "wechat" field. It's identical to WechatEQ.
|
|
||||||
func Wechat(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldEQ(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notes applies equality check predicate on the "notes" field. It's identical to NotesEQ.
|
// Notes applies equality check predicate on the "notes" field. It's identical to NotesEQ.
|
||||||
func Notes(v string) predicate.User {
|
func Notes(v string) predicate.User {
|
||||||
return predicate.User(sql.FieldEQ(FieldNotes, v))
|
return predicate.User(sql.FieldEQ(FieldNotes, v))
|
||||||
@@ -650,71 +645,6 @@ func UsernameContainsFold(v string) predicate.User {
|
|||||||
return predicate.User(sql.FieldContainsFold(FieldUsername, v))
|
return predicate.User(sql.FieldContainsFold(FieldUsername, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// WechatEQ applies the EQ predicate on the "wechat" field.
|
|
||||||
func WechatEQ(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldEQ(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatNEQ applies the NEQ predicate on the "wechat" field.
|
|
||||||
func WechatNEQ(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldNEQ(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatIn applies the In predicate on the "wechat" field.
|
|
||||||
func WechatIn(vs ...string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldIn(FieldWechat, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatNotIn applies the NotIn predicate on the "wechat" field.
|
|
||||||
func WechatNotIn(vs ...string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldNotIn(FieldWechat, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatGT applies the GT predicate on the "wechat" field.
|
|
||||||
func WechatGT(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldGT(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatGTE applies the GTE predicate on the "wechat" field.
|
|
||||||
func WechatGTE(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldGTE(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatLT applies the LT predicate on the "wechat" field.
|
|
||||||
func WechatLT(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldLT(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatLTE applies the LTE predicate on the "wechat" field.
|
|
||||||
func WechatLTE(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldLTE(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatContains applies the Contains predicate on the "wechat" field.
|
|
||||||
func WechatContains(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldContains(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatHasPrefix applies the HasPrefix predicate on the "wechat" field.
|
|
||||||
func WechatHasPrefix(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldHasPrefix(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatHasSuffix applies the HasSuffix predicate on the "wechat" field.
|
|
||||||
func WechatHasSuffix(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldHasSuffix(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatEqualFold applies the EqualFold predicate on the "wechat" field.
|
|
||||||
func WechatEqualFold(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldEqualFold(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WechatContainsFold applies the ContainsFold predicate on the "wechat" field.
|
|
||||||
func WechatContainsFold(v string) predicate.User {
|
|
||||||
return predicate.User(sql.FieldContainsFold(FieldWechat, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotesEQ applies the EQ predicate on the "notes" field.
|
// NotesEQ applies the EQ predicate on the "notes" field.
|
||||||
func NotesEQ(v string) predicate.User {
|
func NotesEQ(v string) predicate.User {
|
||||||
return predicate.User(sql.FieldEQ(FieldNotes, v))
|
return predicate.User(sql.FieldEQ(FieldNotes, v))
|
||||||
@@ -918,6 +848,29 @@ func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.User {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasAttributeValues applies the HasEdge predicate on the "attribute_values" edge.
|
||||||
|
func HasAttributeValues() predicate.User {
|
||||||
|
return predicate.User(func(s *sql.Selector) {
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(Table, FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.O2M, false, AttributeValuesTable, AttributeValuesColumn),
|
||||||
|
)
|
||||||
|
sqlgraph.HasNeighbors(s, step)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasAttributeValuesWith applies the HasEdge predicate on the "attribute_values" edge with a given conditions (other predicates).
|
||||||
|
func HasAttributeValuesWith(preds ...predicate.UserAttributeValue) predicate.User {
|
||||||
|
return predicate.User(func(s *sql.Selector) {
|
||||||
|
step := newAttributeValuesStep()
|
||||||
|
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||||
|
for _, p := range preds {
|
||||||
|
p(s)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
|
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
|
||||||
func HasUserAllowedGroups() predicate.User {
|
func HasUserAllowedGroups() predicate.User {
|
||||||
return predicate.User(func(s *sql.Selector) {
|
return predicate.User(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -151,20 +152,6 @@ func (_c *UserCreate) SetNillableUsername(v *string) *UserCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWechat sets the "wechat" field.
|
|
||||||
func (_c *UserCreate) SetWechat(v string) *UserCreate {
|
|
||||||
_c.mutation.SetWechat(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableWechat sets the "wechat" field if the given value is not nil.
|
|
||||||
func (_c *UserCreate) SetNillableWechat(v *string) *UserCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetWechat(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNotes sets the "notes" field.
|
// SetNotes sets the "notes" field.
|
||||||
func (_c *UserCreate) SetNotes(v string) *UserCreate {
|
func (_c *UserCreate) SetNotes(v string) *UserCreate {
|
||||||
_c.mutation.SetNotes(v)
|
_c.mutation.SetNotes(v)
|
||||||
@@ -269,6 +256,21 @@ func (_c *UserCreate) AddUsageLogs(v ...*UsageLog) *UserCreate {
|
|||||||
return _c.AddUsageLogIDs(ids...)
|
return _c.AddUsageLogIDs(ids...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddAttributeValueIDs adds the "attribute_values" edge to the UserAttributeValue entity by IDs.
|
||||||
|
func (_c *UserCreate) AddAttributeValueIDs(ids ...int64) *UserCreate {
|
||||||
|
_c.mutation.AddAttributeValueIDs(ids...)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAttributeValues adds the "attribute_values" edges to the UserAttributeValue entity.
|
||||||
|
func (_c *UserCreate) AddAttributeValues(v ...*UserAttributeValue) *UserCreate {
|
||||||
|
ids := make([]int64, len(v))
|
||||||
|
for i := range v {
|
||||||
|
ids[i] = v[i].ID
|
||||||
|
}
|
||||||
|
return _c.AddAttributeValueIDs(ids...)
|
||||||
|
}
|
||||||
|
|
||||||
// Mutation returns the UserMutation object of the builder.
|
// Mutation returns the UserMutation object of the builder.
|
||||||
func (_c *UserCreate) Mutation() *UserMutation {
|
func (_c *UserCreate) Mutation() *UserMutation {
|
||||||
return _c.mutation
|
return _c.mutation
|
||||||
@@ -340,10 +342,6 @@ func (_c *UserCreate) defaults() error {
|
|||||||
v := user.DefaultUsername
|
v := user.DefaultUsername
|
||||||
_c.mutation.SetUsername(v)
|
_c.mutation.SetUsername(v)
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.Wechat(); !ok {
|
|
||||||
v := user.DefaultWechat
|
|
||||||
_c.mutation.SetWechat(v)
|
|
||||||
}
|
|
||||||
if _, ok := _c.mutation.Notes(); !ok {
|
if _, ok := _c.mutation.Notes(); !ok {
|
||||||
v := user.DefaultNotes
|
v := user.DefaultNotes
|
||||||
_c.mutation.SetNotes(v)
|
_c.mutation.SetNotes(v)
|
||||||
@@ -405,14 +403,6 @@ func (_c *UserCreate) check() error {
|
|||||||
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
|
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.Wechat(); !ok {
|
|
||||||
return &ValidationError{Name: "wechat", err: errors.New(`ent: missing required field "User.wechat"`)}
|
|
||||||
}
|
|
||||||
if v, ok := _c.mutation.Wechat(); ok {
|
|
||||||
if err := user.WechatValidator(v); err != nil {
|
|
||||||
return &ValidationError{Name: "wechat", err: fmt.Errorf(`ent: validator failed for field "User.wechat": %w`, err)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, ok := _c.mutation.Notes(); !ok {
|
if _, ok := _c.mutation.Notes(); !ok {
|
||||||
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
|
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
|
||||||
}
|
}
|
||||||
@@ -483,10 +473,6 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(user.FieldUsername, field.TypeString, value)
|
_spec.SetField(user.FieldUsername, field.TypeString, value)
|
||||||
_node.Username = value
|
_node.Username = value
|
||||||
}
|
}
|
||||||
if value, ok := _c.mutation.Wechat(); ok {
|
|
||||||
_spec.SetField(user.FieldWechat, field.TypeString, value)
|
|
||||||
_node.Wechat = value
|
|
||||||
}
|
|
||||||
if value, ok := _c.mutation.Notes(); ok {
|
if value, ok := _c.mutation.Notes(); ok {
|
||||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||||
_node.Notes = value
|
_node.Notes = value
|
||||||
@@ -591,6 +577,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
|||||||
}
|
}
|
||||||
_spec.Edges = append(_spec.Edges, edge)
|
_spec.Edges = append(_spec.Edges, edge)
|
||||||
}
|
}
|
||||||
|
if nodes := _c.mutation.AttributeValuesIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: user.AttributeValuesTable,
|
||||||
|
Columns: []string{user.AttributeValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges = append(_spec.Edges, edge)
|
||||||
|
}
|
||||||
return _node, _spec
|
return _node, _spec
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -769,18 +771,6 @@ func (u *UserUpsert) UpdateUsername() *UserUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWechat sets the "wechat" field.
|
|
||||||
func (u *UserUpsert) SetWechat(v string) *UserUpsert {
|
|
||||||
u.Set(user.FieldWechat, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateWechat sets the "wechat" field to the value that was provided on create.
|
|
||||||
func (u *UserUpsert) UpdateWechat() *UserUpsert {
|
|
||||||
u.SetExcluded(user.FieldWechat)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNotes sets the "notes" field.
|
// SetNotes sets the "notes" field.
|
||||||
func (u *UserUpsert) SetNotes(v string) *UserUpsert {
|
func (u *UserUpsert) SetNotes(v string) *UserUpsert {
|
||||||
u.Set(user.FieldNotes, v)
|
u.Set(user.FieldNotes, v)
|
||||||
@@ -985,20 +975,6 @@ func (u *UserUpsertOne) UpdateUsername() *UserUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWechat sets the "wechat" field.
|
|
||||||
func (u *UserUpsertOne) SetWechat(v string) *UserUpsertOne {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.SetWechat(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateWechat sets the "wechat" field to the value that was provided on create.
|
|
||||||
func (u *UserUpsertOne) UpdateWechat() *UserUpsertOne {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.UpdateWechat()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNotes sets the "notes" field.
|
// SetNotes sets the "notes" field.
|
||||||
func (u *UserUpsertOne) SetNotes(v string) *UserUpsertOne {
|
func (u *UserUpsertOne) SetNotes(v string) *UserUpsertOne {
|
||||||
return u.Update(func(s *UserUpsert) {
|
return u.Update(func(s *UserUpsert) {
|
||||||
@@ -1371,20 +1347,6 @@ func (u *UserUpsertBulk) UpdateUsername() *UserUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWechat sets the "wechat" field.
|
|
||||||
func (u *UserUpsertBulk) SetWechat(v string) *UserUpsertBulk {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.SetWechat(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateWechat sets the "wechat" field to the value that was provided on create.
|
|
||||||
func (u *UserUpsertBulk) UpdateWechat() *UserUpsertBulk {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.UpdateWechat()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNotes sets the "notes" field.
|
// SetNotes sets the "notes" field.
|
||||||
func (u *UserUpsertBulk) SetNotes(v string) *UserUpsertBulk {
|
func (u *UserUpsertBulk) SetNotes(v string) *UserUpsertBulk {
|
||||||
return u.Update(func(s *UserUpsert) {
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,6 +36,7 @@ type UserQuery struct {
|
|||||||
withAssignedSubscriptions *UserSubscriptionQuery
|
withAssignedSubscriptions *UserSubscriptionQuery
|
||||||
withAllowedGroups *GroupQuery
|
withAllowedGroups *GroupQuery
|
||||||
withUsageLogs *UsageLogQuery
|
withUsageLogs *UsageLogQuery
|
||||||
|
withAttributeValues *UserAttributeValueQuery
|
||||||
withUserAllowedGroups *UserAllowedGroupQuery
|
withUserAllowedGroups *UserAllowedGroupQuery
|
||||||
// intermediate query (i.e. traversal path).
|
// intermediate query (i.e. traversal path).
|
||||||
sql *sql.Selector
|
sql *sql.Selector
|
||||||
@@ -204,6 +206,28 @@ func (_q *UserQuery) QueryUsageLogs() *UsageLogQuery {
|
|||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryAttributeValues chains the current query on the "attribute_values" edge.
|
||||||
|
func (_q *UserQuery) QueryAttributeValues() *UserAttributeValueQuery {
|
||||||
|
query := (&UserAttributeValueClient{config: _q.config}).Query()
|
||||||
|
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
selector := _q.sqlQuery(ctx)
|
||||||
|
if err := selector.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(user.Table, user.FieldID, selector),
|
||||||
|
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.O2M, false, user.AttributeValuesTable, user.AttributeValuesColumn),
|
||||||
|
)
|
||||||
|
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||||
|
return fromU, nil
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
|
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
|
||||||
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
||||||
query := (&UserAllowedGroupClient{config: _q.config}).Query()
|
query := (&UserAllowedGroupClient{config: _q.config}).Query()
|
||||||
@@ -424,6 +448,7 @@ func (_q *UserQuery) Clone() *UserQuery {
|
|||||||
withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(),
|
withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(),
|
||||||
withAllowedGroups: _q.withAllowedGroups.Clone(),
|
withAllowedGroups: _q.withAllowedGroups.Clone(),
|
||||||
withUsageLogs: _q.withUsageLogs.Clone(),
|
withUsageLogs: _q.withUsageLogs.Clone(),
|
||||||
|
withAttributeValues: _q.withAttributeValues.Clone(),
|
||||||
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
|
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
|
||||||
// clone intermediate query.
|
// clone intermediate query.
|
||||||
sql: _q.sql.Clone(),
|
sql: _q.sql.Clone(),
|
||||||
@@ -497,6 +522,17 @@ func (_q *UserQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserQuery {
|
|||||||
return _q
|
return _q
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithAttributeValues tells the query-builder to eager-load the nodes that are connected to
|
||||||
|
// the "attribute_values" edge. The optional arguments are used to configure the query builder of the edge.
|
||||||
|
func (_q *UserQuery) WithAttributeValues(opts ...func(*UserAttributeValueQuery)) *UserQuery {
|
||||||
|
query := (&UserAttributeValueClient{config: _q.config}).Query()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(query)
|
||||||
|
}
|
||||||
|
_q.withAttributeValues = query
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
|
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
|
||||||
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
|
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
|
||||||
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
|
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
|
||||||
@@ -586,13 +622,14 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
|||||||
var (
|
var (
|
||||||
nodes = []*User{}
|
nodes = []*User{}
|
||||||
_spec = _q.querySpec()
|
_spec = _q.querySpec()
|
||||||
loadedTypes = [7]bool{
|
loadedTypes = [8]bool{
|
||||||
_q.withAPIKeys != nil,
|
_q.withAPIKeys != nil,
|
||||||
_q.withRedeemCodes != nil,
|
_q.withRedeemCodes != nil,
|
||||||
_q.withSubscriptions != nil,
|
_q.withSubscriptions != nil,
|
||||||
_q.withAssignedSubscriptions != nil,
|
_q.withAssignedSubscriptions != nil,
|
||||||
_q.withAllowedGroups != nil,
|
_q.withAllowedGroups != nil,
|
||||||
_q.withUsageLogs != nil,
|
_q.withUsageLogs != nil,
|
||||||
|
_q.withAttributeValues != nil,
|
||||||
_q.withUserAllowedGroups != nil,
|
_q.withUserAllowedGroups != nil,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -658,6 +695,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if query := _q.withAttributeValues; query != nil {
|
||||||
|
if err := _q.loadAttributeValues(ctx, query, nodes,
|
||||||
|
func(n *User) { n.Edges.AttributeValues = []*UserAttributeValue{} },
|
||||||
|
func(n *User, e *UserAttributeValue) { n.Edges.AttributeValues = append(n.Edges.AttributeValues, e) }); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
if query := _q.withUserAllowedGroups; query != nil {
|
if query := _q.withUserAllowedGroups; query != nil {
|
||||||
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
|
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
|
||||||
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
|
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
|
||||||
@@ -885,6 +929,36 @@ func (_q *UserQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, no
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (_q *UserQuery) loadAttributeValues(ctx context.Context, query *UserAttributeValueQuery, nodes []*User, init func(*User), assign func(*User, *UserAttributeValue)) error {
|
||||||
|
fks := make([]driver.Value, 0, len(nodes))
|
||||||
|
nodeids := make(map[int64]*User)
|
||||||
|
for i := range nodes {
|
||||||
|
fks = append(fks, nodes[i].ID)
|
||||||
|
nodeids[nodes[i].ID] = nodes[i]
|
||||||
|
if init != nil {
|
||||||
|
init(nodes[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(query.ctx.Fields) > 0 {
|
||||||
|
query.ctx.AppendFieldOnce(userattributevalue.FieldUserID)
|
||||||
|
}
|
||||||
|
query.Where(predicate.UserAttributeValue(func(s *sql.Selector) {
|
||||||
|
s.Where(sql.InValues(s.C(user.AttributeValuesColumn), fks...))
|
||||||
|
}))
|
||||||
|
neighbors, err := query.All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, n := range neighbors {
|
||||||
|
fk := n.UserID
|
||||||
|
node, ok := nodeids[fk]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
|
||||||
|
}
|
||||||
|
assign(node, n)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
|
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
|
||||||
fks := make([]driver.Value, 0, len(nodes))
|
fks := make([]driver.Value, 0, len(nodes))
|
||||||
nodeids := make(map[int64]*User)
|
nodeids := make(map[int64]*User)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -171,20 +172,6 @@ func (_u *UserUpdate) SetNillableUsername(v *string) *UserUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWechat sets the "wechat" field.
|
|
||||||
func (_u *UserUpdate) SetWechat(v string) *UserUpdate {
|
|
||||||
_u.mutation.SetWechat(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableWechat sets the "wechat" field if the given value is not nil.
|
|
||||||
func (_u *UserUpdate) SetNillableWechat(v *string) *UserUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetWechat(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNotes sets the "notes" field.
|
// SetNotes sets the "notes" field.
|
||||||
func (_u *UserUpdate) SetNotes(v string) *UserUpdate {
|
func (_u *UserUpdate) SetNotes(v string) *UserUpdate {
|
||||||
_u.mutation.SetNotes(v)
|
_u.mutation.SetNotes(v)
|
||||||
@@ -289,6 +276,21 @@ func (_u *UserUpdate) AddUsageLogs(v ...*UsageLog) *UserUpdate {
|
|||||||
return _u.AddUsageLogIDs(ids...)
|
return _u.AddUsageLogIDs(ids...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddAttributeValueIDs adds the "attribute_values" edge to the UserAttributeValue entity by IDs.
|
||||||
|
func (_u *UserUpdate) AddAttributeValueIDs(ids ...int64) *UserUpdate {
|
||||||
|
_u.mutation.AddAttributeValueIDs(ids...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAttributeValues adds the "attribute_values" edges to the UserAttributeValue entity.
|
||||||
|
func (_u *UserUpdate) AddAttributeValues(v ...*UserAttributeValue) *UserUpdate {
|
||||||
|
ids := make([]int64, len(v))
|
||||||
|
for i := range v {
|
||||||
|
ids[i] = v[i].ID
|
||||||
|
}
|
||||||
|
return _u.AddAttributeValueIDs(ids...)
|
||||||
|
}
|
||||||
|
|
||||||
// Mutation returns the UserMutation object of the builder.
|
// Mutation returns the UserMutation object of the builder.
|
||||||
func (_u *UserUpdate) Mutation() *UserMutation {
|
func (_u *UserUpdate) Mutation() *UserMutation {
|
||||||
return _u.mutation
|
return _u.mutation
|
||||||
@@ -420,6 +422,27 @@ func (_u *UserUpdate) RemoveUsageLogs(v ...*UsageLog) *UserUpdate {
|
|||||||
return _u.RemoveUsageLogIDs(ids...)
|
return _u.RemoveUsageLogIDs(ids...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClearAttributeValues clears all "attribute_values" edges to the UserAttributeValue entity.
|
||||||
|
func (_u *UserUpdate) ClearAttributeValues() *UserUpdate {
|
||||||
|
_u.mutation.ClearAttributeValues()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAttributeValueIDs removes the "attribute_values" edge to UserAttributeValue entities by IDs.
|
||||||
|
func (_u *UserUpdate) RemoveAttributeValueIDs(ids ...int64) *UserUpdate {
|
||||||
|
_u.mutation.RemoveAttributeValueIDs(ids...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAttributeValues removes "attribute_values" edges to UserAttributeValue entities.
|
||||||
|
func (_u *UserUpdate) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdate {
|
||||||
|
ids := make([]int64, len(v))
|
||||||
|
for i := range v {
|
||||||
|
ids[i] = v[i].ID
|
||||||
|
}
|
||||||
|
return _u.RemoveAttributeValueIDs(ids...)
|
||||||
|
}
|
||||||
|
|
||||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||||
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
|
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
|
||||||
if err := _u.defaults(); err != nil {
|
if err := _u.defaults(); err != nil {
|
||||||
@@ -489,11 +512,6 @@ func (_u *UserUpdate) check() error {
|
|||||||
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
|
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v, ok := _u.mutation.Wechat(); ok {
|
|
||||||
if err := user.WechatValidator(v); err != nil {
|
|
||||||
return &ValidationError{Name: "wechat", err: fmt.Errorf(`ent: validator failed for field "User.wechat": %w`, err)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -545,9 +563,6 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.Username(); ok {
|
if value, ok := _u.mutation.Username(); ok {
|
||||||
_spec.SetField(user.FieldUsername, field.TypeString, value)
|
_spec.SetField(user.FieldUsername, field.TypeString, value)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.Wechat(); ok {
|
|
||||||
_spec.SetField(user.FieldWechat, field.TypeString, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.Notes(); ok {
|
if value, ok := _u.mutation.Notes(); ok {
|
||||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||||
}
|
}
|
||||||
@@ -833,6 +848,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
}
|
}
|
||||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
}
|
}
|
||||||
|
if _u.mutation.AttributeValuesCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: user.AttributeValuesTable,
|
||||||
|
Columns: []string{user.AttributeValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.RemovedAttributeValuesIDs(); len(nodes) > 0 && !_u.mutation.AttributeValuesCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: user.AttributeValuesTable,
|
||||||
|
Columns: []string{user.AttributeValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.AttributeValuesIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: user.AttributeValuesTable,
|
||||||
|
Columns: []string{user.AttributeValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
|
}
|
||||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||||
err = &NotFoundError{user.Label}
|
err = &NotFoundError{user.Label}
|
||||||
@@ -991,20 +1051,6 @@ func (_u *UserUpdateOne) SetNillableUsername(v *string) *UserUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWechat sets the "wechat" field.
|
|
||||||
func (_u *UserUpdateOne) SetWechat(v string) *UserUpdateOne {
|
|
||||||
_u.mutation.SetWechat(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableWechat sets the "wechat" field if the given value is not nil.
|
|
||||||
func (_u *UserUpdateOne) SetNillableWechat(v *string) *UserUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetWechat(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNotes sets the "notes" field.
|
// SetNotes sets the "notes" field.
|
||||||
func (_u *UserUpdateOne) SetNotes(v string) *UserUpdateOne {
|
func (_u *UserUpdateOne) SetNotes(v string) *UserUpdateOne {
|
||||||
_u.mutation.SetNotes(v)
|
_u.mutation.SetNotes(v)
|
||||||
@@ -1109,6 +1155,21 @@ func (_u *UserUpdateOne) AddUsageLogs(v ...*UsageLog) *UserUpdateOne {
|
|||||||
return _u.AddUsageLogIDs(ids...)
|
return _u.AddUsageLogIDs(ids...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddAttributeValueIDs adds the "attribute_values" edge to the UserAttributeValue entity by IDs.
|
||||||
|
func (_u *UserUpdateOne) AddAttributeValueIDs(ids ...int64) *UserUpdateOne {
|
||||||
|
_u.mutation.AddAttributeValueIDs(ids...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAttributeValues adds the "attribute_values" edges to the UserAttributeValue entity.
|
||||||
|
func (_u *UserUpdateOne) AddAttributeValues(v ...*UserAttributeValue) *UserUpdateOne {
|
||||||
|
ids := make([]int64, len(v))
|
||||||
|
for i := range v {
|
||||||
|
ids[i] = v[i].ID
|
||||||
|
}
|
||||||
|
return _u.AddAttributeValueIDs(ids...)
|
||||||
|
}
|
||||||
|
|
||||||
// Mutation returns the UserMutation object of the builder.
|
// Mutation returns the UserMutation object of the builder.
|
||||||
func (_u *UserUpdateOne) Mutation() *UserMutation {
|
func (_u *UserUpdateOne) Mutation() *UserMutation {
|
||||||
return _u.mutation
|
return _u.mutation
|
||||||
@@ -1240,6 +1301,27 @@ func (_u *UserUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserUpdateOne {
|
|||||||
return _u.RemoveUsageLogIDs(ids...)
|
return _u.RemoveUsageLogIDs(ids...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClearAttributeValues clears all "attribute_values" edges to the UserAttributeValue entity.
|
||||||
|
func (_u *UserUpdateOne) ClearAttributeValues() *UserUpdateOne {
|
||||||
|
_u.mutation.ClearAttributeValues()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAttributeValueIDs removes the "attribute_values" edge to UserAttributeValue entities by IDs.
|
||||||
|
func (_u *UserUpdateOne) RemoveAttributeValueIDs(ids ...int64) *UserUpdateOne {
|
||||||
|
_u.mutation.RemoveAttributeValueIDs(ids...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAttributeValues removes "attribute_values" edges to UserAttributeValue entities.
|
||||||
|
func (_u *UserUpdateOne) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdateOne {
|
||||||
|
ids := make([]int64, len(v))
|
||||||
|
for i := range v {
|
||||||
|
ids[i] = v[i].ID
|
||||||
|
}
|
||||||
|
return _u.RemoveAttributeValueIDs(ids...)
|
||||||
|
}
|
||||||
|
|
||||||
// Where appends a list predicates to the UserUpdate builder.
|
// Where appends a list predicates to the UserUpdate builder.
|
||||||
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
|
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
|
||||||
_u.mutation.Where(ps...)
|
_u.mutation.Where(ps...)
|
||||||
@@ -1322,11 +1404,6 @@ func (_u *UserUpdateOne) check() error {
|
|||||||
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
|
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v, ok := _u.mutation.Wechat(); ok {
|
|
||||||
if err := user.WechatValidator(v); err != nil {
|
|
||||||
return &ValidationError{Name: "wechat", err: fmt.Errorf(`ent: validator failed for field "User.wechat": %w`, err)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1395,9 +1472,6 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
|||||||
if value, ok := _u.mutation.Username(); ok {
|
if value, ok := _u.mutation.Username(); ok {
|
||||||
_spec.SetField(user.FieldUsername, field.TypeString, value)
|
_spec.SetField(user.FieldUsername, field.TypeString, value)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.Wechat(); ok {
|
|
||||||
_spec.SetField(user.FieldWechat, field.TypeString, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.Notes(); ok {
|
if value, ok := _u.mutation.Notes(); ok {
|
||||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||||
}
|
}
|
||||||
@@ -1683,6 +1757,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
|||||||
}
|
}
|
||||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
}
|
}
|
||||||
|
if _u.mutation.AttributeValuesCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: user.AttributeValuesTable,
|
||||||
|
Columns: []string{user.AttributeValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.RemovedAttributeValuesIDs(); len(nodes) > 0 && !_u.mutation.AttributeValuesCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: user.AttributeValuesTable,
|
||||||
|
Columns: []string{user.AttributeValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.AttributeValuesIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: user.AttributeValuesTable,
|
||||||
|
Columns: []string{user.AttributeValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
|
}
|
||||||
_node = &User{config: _u.config}
|
_node = &User{config: _u.config}
|
||||||
_spec.Assign = _node.assignValues
|
_spec.Assign = _node.assignValues
|
||||||
_spec.ScanValues = _node.scanValues
|
_spec.ScanValues = _node.scanValues
|
||||||
|
|||||||
276
backend/ent/userattributedefinition.go
Normal file
276
backend/ent/userattributedefinition.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeDefinition is the model entity for the UserAttributeDefinition schema.
|
||||||
|
type UserAttributeDefinition struct {
|
||||||
|
config `json:"-"`
|
||||||
|
// ID of the ent.
|
||||||
|
ID int64 `json:"id,omitempty"`
|
||||||
|
// CreatedAt holds the value of the "created_at" field.
|
||||||
|
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
|
// UpdatedAt holds the value of the "updated_at" field.
|
||||||
|
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||||
|
// DeletedAt holds the value of the "deleted_at" field.
|
||||||
|
DeletedAt *time.Time `json:"deleted_at,omitempty"`
|
||||||
|
// Key holds the value of the "key" field.
|
||||||
|
Key string `json:"key,omitempty"`
|
||||||
|
// Name holds the value of the "name" field.
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
// Description holds the value of the "description" field.
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
// Type holds the value of the "type" field.
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
// Options holds the value of the "options" field.
|
||||||
|
Options []map[string]interface{} `json:"options,omitempty"`
|
||||||
|
// Required holds the value of the "required" field.
|
||||||
|
Required bool `json:"required,omitempty"`
|
||||||
|
// Validation holds the value of the "validation" field.
|
||||||
|
Validation map[string]interface{} `json:"validation,omitempty"`
|
||||||
|
// Placeholder holds the value of the "placeholder" field.
|
||||||
|
Placeholder string `json:"placeholder,omitempty"`
|
||||||
|
// DisplayOrder holds the value of the "display_order" field.
|
||||||
|
DisplayOrder int `json:"display_order,omitempty"`
|
||||||
|
// Enabled holds the value of the "enabled" field.
|
||||||
|
Enabled bool `json:"enabled,omitempty"`
|
||||||
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
|
// The values are being populated by the UserAttributeDefinitionQuery when eager-loading is set.
|
||||||
|
Edges UserAttributeDefinitionEdges `json:"edges"`
|
||||||
|
selectValues sql.SelectValues
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeDefinitionEdges holds the relations/edges for other nodes in the graph.
|
||||||
|
type UserAttributeDefinitionEdges struct {
|
||||||
|
// Values holds the value of the values edge.
|
||||||
|
Values []*UserAttributeValue `json:"values,omitempty"`
|
||||||
|
// loadedTypes holds the information for reporting if a
|
||||||
|
// type was loaded (or requested) in eager-loading or not.
|
||||||
|
loadedTypes [1]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValuesOrErr returns the Values value or an error if the edge
|
||||||
|
// was not loaded in eager-loading.
|
||||||
|
func (e UserAttributeDefinitionEdges) ValuesOrErr() ([]*UserAttributeValue, error) {
|
||||||
|
if e.loadedTypes[0] {
|
||||||
|
return e.Values, nil
|
||||||
|
}
|
||||||
|
return nil, &NotLoadedError{edge: "values"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanValues returns the types for scanning values from sql.Rows.
|
||||||
|
func (*UserAttributeDefinition) scanValues(columns []string) ([]any, error) {
|
||||||
|
values := make([]any, len(columns))
|
||||||
|
for i := range columns {
|
||||||
|
switch columns[i] {
|
||||||
|
case userattributedefinition.FieldOptions, userattributedefinition.FieldValidation:
|
||||||
|
values[i] = new([]byte)
|
||||||
|
case userattributedefinition.FieldRequired, userattributedefinition.FieldEnabled:
|
||||||
|
values[i] = new(sql.NullBool)
|
||||||
|
case userattributedefinition.FieldID, userattributedefinition.FieldDisplayOrder:
|
||||||
|
values[i] = new(sql.NullInt64)
|
||||||
|
case userattributedefinition.FieldKey, userattributedefinition.FieldName, userattributedefinition.FieldDescription, userattributedefinition.FieldType, userattributedefinition.FieldPlaceholder:
|
||||||
|
values[i] = new(sql.NullString)
|
||||||
|
case userattributedefinition.FieldCreatedAt, userattributedefinition.FieldUpdatedAt, userattributedefinition.FieldDeletedAt:
|
||||||
|
values[i] = new(sql.NullTime)
|
||||||
|
default:
|
||||||
|
values[i] = new(sql.UnknownType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||||
|
// to the UserAttributeDefinition fields.
|
||||||
|
func (_m *UserAttributeDefinition) assignValues(columns []string, values []any) error {
|
||||||
|
if m, n := len(values), len(columns); m < n {
|
||||||
|
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
|
||||||
|
}
|
||||||
|
for i := range columns {
|
||||||
|
switch columns[i] {
|
||||||
|
case userattributedefinition.FieldID:
|
||||||
|
value, ok := values[i].(*sql.NullInt64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field id", value)
|
||||||
|
}
|
||||||
|
_m.ID = int64(value.Int64)
|
||||||
|
case userattributedefinition.FieldCreatedAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.CreatedAt = value.Time
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldUpdatedAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.UpdatedAt = value.Time
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldDeletedAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.DeletedAt = new(time.Time)
|
||||||
|
*_m.DeletedAt = value.Time
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldKey:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field key", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Key = value.String
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldName:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field name", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Name = value.String
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldDescription:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field description", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Description = value.String
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldType:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field type", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Type = value.String
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldOptions:
|
||||||
|
if value, ok := values[i].(*[]byte); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field options", values[i])
|
||||||
|
} else if value != nil && len(*value) > 0 {
|
||||||
|
if err := json.Unmarshal(*value, &_m.Options); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal field options: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldRequired:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field required", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Required = value.Bool
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldValidation:
|
||||||
|
if value, ok := values[i].(*[]byte); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field validation", values[i])
|
||||||
|
} else if value != nil && len(*value) > 0 {
|
||||||
|
if err := json.Unmarshal(*value, &_m.Validation); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal field validation: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldPlaceholder:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field placeholder", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Placeholder = value.String
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldDisplayOrder:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field display_order", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.DisplayOrder = int(value.Int64)
|
||||||
|
}
|
||||||
|
case userattributedefinition.FieldEnabled:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field enabled", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Enabled = value.Bool
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value returns the ent.Value that was dynamically selected and assigned to the UserAttributeDefinition.
|
||||||
|
// This includes values selected through modifiers, order, etc.
|
||||||
|
func (_m *UserAttributeDefinition) Value(name string) (ent.Value, error) {
|
||||||
|
return _m.selectValues.Get(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryValues queries the "values" edge of the UserAttributeDefinition entity.
|
||||||
|
func (_m *UserAttributeDefinition) QueryValues() *UserAttributeValueQuery {
|
||||||
|
return NewUserAttributeDefinitionClient(_m.config).QueryValues(_m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update returns a builder for updating this UserAttributeDefinition.
|
||||||
|
// Note that you need to call UserAttributeDefinition.Unwrap() before calling this method if this UserAttributeDefinition
|
||||||
|
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||||
|
func (_m *UserAttributeDefinition) Update() *UserAttributeDefinitionUpdateOne {
|
||||||
|
return NewUserAttributeDefinitionClient(_m.config).UpdateOne(_m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap unwraps the UserAttributeDefinition entity that was returned from a transaction after it was closed,
|
||||||
|
// so that all future queries will be executed through the driver which created the transaction.
|
||||||
|
func (_m *UserAttributeDefinition) Unwrap() *UserAttributeDefinition {
|
||||||
|
_tx, ok := _m.config.driver.(*txDriver)
|
||||||
|
if !ok {
|
||||||
|
panic("ent: UserAttributeDefinition is not a transactional entity")
|
||||||
|
}
|
||||||
|
_m.config.driver = _tx.drv
|
||||||
|
return _m
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements the fmt.Stringer.
|
||||||
|
func (_m *UserAttributeDefinition) String() string {
|
||||||
|
var builder strings.Builder
|
||||||
|
builder.WriteString("UserAttributeDefinition(")
|
||||||
|
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
|
||||||
|
builder.WriteString("created_at=")
|
||||||
|
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("updated_at=")
|
||||||
|
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.DeletedAt; v != nil {
|
||||||
|
builder.WriteString("deleted_at=")
|
||||||
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("key=")
|
||||||
|
builder.WriteString(_m.Key)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("name=")
|
||||||
|
builder.WriteString(_m.Name)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("description=")
|
||||||
|
builder.WriteString(_m.Description)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("type=")
|
||||||
|
builder.WriteString(_m.Type)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("options=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.Options))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("required=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.Required))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("validation=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.Validation))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("placeholder=")
|
||||||
|
builder.WriteString(_m.Placeholder)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("display_order=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.DisplayOrder))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("enabled=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.Enabled))
|
||||||
|
builder.WriteByte(')')
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeDefinitions is a parsable slice of UserAttributeDefinition.
|
||||||
|
type UserAttributeDefinitions []*UserAttributeDefinition
|
||||||
205
backend/ent/userattributedefinition/userattributedefinition.go
Normal file
205
backend/ent/userattributedefinition/userattributedefinition.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package userattributedefinition
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Label holds the string label denoting the userattributedefinition type in the database.
|
||||||
|
Label = "user_attribute_definition"
|
||||||
|
// FieldID holds the string denoting the id field in the database.
|
||||||
|
FieldID = "id"
|
||||||
|
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||||
|
FieldCreatedAt = "created_at"
|
||||||
|
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
||||||
|
FieldUpdatedAt = "updated_at"
|
||||||
|
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
|
||||||
|
FieldDeletedAt = "deleted_at"
|
||||||
|
// FieldKey holds the string denoting the key field in the database.
|
||||||
|
FieldKey = "key"
|
||||||
|
// FieldName holds the string denoting the name field in the database.
|
||||||
|
FieldName = "name"
|
||||||
|
// FieldDescription holds the string denoting the description field in the database.
|
||||||
|
FieldDescription = "description"
|
||||||
|
// FieldType holds the string denoting the type field in the database.
|
||||||
|
FieldType = "type"
|
||||||
|
// FieldOptions holds the string denoting the options field in the database.
|
||||||
|
FieldOptions = "options"
|
||||||
|
// FieldRequired holds the string denoting the required field in the database.
|
||||||
|
FieldRequired = "required"
|
||||||
|
// FieldValidation holds the string denoting the validation field in the database.
|
||||||
|
FieldValidation = "validation"
|
||||||
|
// FieldPlaceholder holds the string denoting the placeholder field in the database.
|
||||||
|
FieldPlaceholder = "placeholder"
|
||||||
|
// FieldDisplayOrder holds the string denoting the display_order field in the database.
|
||||||
|
FieldDisplayOrder = "display_order"
|
||||||
|
// FieldEnabled holds the string denoting the enabled field in the database.
|
||||||
|
FieldEnabled = "enabled"
|
||||||
|
// EdgeValues holds the string denoting the values edge name in mutations.
|
||||||
|
EdgeValues = "values"
|
||||||
|
// Table holds the table name of the userattributedefinition in the database.
|
||||||
|
Table = "user_attribute_definitions"
|
||||||
|
// ValuesTable is the table that holds the values relation/edge.
|
||||||
|
ValuesTable = "user_attribute_values"
|
||||||
|
// ValuesInverseTable is the table name for the UserAttributeValue entity.
|
||||||
|
// It exists in this package in order to avoid circular dependency with the "userattributevalue" package.
|
||||||
|
ValuesInverseTable = "user_attribute_values"
|
||||||
|
// ValuesColumn is the table column denoting the values relation/edge.
|
||||||
|
ValuesColumn = "attribute_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Columns holds all SQL columns for userattributedefinition fields.
|
||||||
|
var Columns = []string{
|
||||||
|
FieldID,
|
||||||
|
FieldCreatedAt,
|
||||||
|
FieldUpdatedAt,
|
||||||
|
FieldDeletedAt,
|
||||||
|
FieldKey,
|
||||||
|
FieldName,
|
||||||
|
FieldDescription,
|
||||||
|
FieldType,
|
||||||
|
FieldOptions,
|
||||||
|
FieldRequired,
|
||||||
|
FieldValidation,
|
||||||
|
FieldPlaceholder,
|
||||||
|
FieldDisplayOrder,
|
||||||
|
FieldEnabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||||
|
func ValidColumn(column string) bool {
|
||||||
|
for i := range Columns {
|
||||||
|
if column == Columns[i] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note that the variables below are initialized by the runtime
|
||||||
|
// package on the initialization of the application. Therefore,
|
||||||
|
// it should be imported in the main as follows:
|
||||||
|
//
|
||||||
|
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||||
|
var (
|
||||||
|
Hooks [1]ent.Hook
|
||||||
|
Interceptors [1]ent.Interceptor
|
||||||
|
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||||
|
DefaultCreatedAt func() time.Time
|
||||||
|
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||||
|
DefaultUpdatedAt func() time.Time
|
||||||
|
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
|
||||||
|
UpdateDefaultUpdatedAt func() time.Time
|
||||||
|
// KeyValidator is a validator for the "key" field. It is called by the builders before save.
|
||||||
|
KeyValidator func(string) error
|
||||||
|
// NameValidator is a validator for the "name" field. It is called by the builders before save.
|
||||||
|
NameValidator func(string) error
|
||||||
|
// DefaultDescription holds the default value on creation for the "description" field.
|
||||||
|
DefaultDescription string
|
||||||
|
// TypeValidator is a validator for the "type" field. It is called by the builders before save.
|
||||||
|
TypeValidator func(string) error
|
||||||
|
// DefaultOptions holds the default value on creation for the "options" field.
|
||||||
|
DefaultOptions []map[string]interface{}
|
||||||
|
// DefaultRequired holds the default value on creation for the "required" field.
|
||||||
|
DefaultRequired bool
|
||||||
|
// DefaultValidation holds the default value on creation for the "validation" field.
|
||||||
|
DefaultValidation map[string]interface{}
|
||||||
|
// DefaultPlaceholder holds the default value on creation for the "placeholder" field.
|
||||||
|
DefaultPlaceholder string
|
||||||
|
// PlaceholderValidator is a validator for the "placeholder" field. It is called by the builders before save.
|
||||||
|
PlaceholderValidator func(string) error
|
||||||
|
// DefaultDisplayOrder holds the default value on creation for the "display_order" field.
|
||||||
|
DefaultDisplayOrder int
|
||||||
|
// DefaultEnabled holds the default value on creation for the "enabled" field.
|
||||||
|
DefaultEnabled bool
|
||||||
|
)
|
||||||
|
|
||||||
|
// OrderOption defines the ordering options for the UserAttributeDefinition queries.
|
||||||
|
type OrderOption func(*sql.Selector)
|
||||||
|
|
||||||
|
// ByID orders the results by the id field.
|
||||||
|
func ByID(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldID, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByCreatedAt orders the results by the created_at field.
|
||||||
|
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByUpdatedAt orders the results by the updated_at field.
|
||||||
|
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByDeletedAt orders the results by the deleted_at field.
|
||||||
|
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByKey orders the results by the key field.
|
||||||
|
func ByKey(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldKey, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByName orders the results by the name field.
|
||||||
|
func ByName(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldName, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByDescription orders the results by the description field.
|
||||||
|
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldDescription, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByType orders the results by the type field.
|
||||||
|
func ByType(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldType, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByRequired orders the results by the required field.
|
||||||
|
func ByRequired(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldRequired, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByPlaceholder orders the results by the placeholder field.
|
||||||
|
func ByPlaceholder(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldPlaceholder, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByDisplayOrder orders the results by the display_order field.
|
||||||
|
func ByDisplayOrder(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldDisplayOrder, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByEnabled orders the results by the enabled field.
|
||||||
|
func ByEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldEnabled, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByValuesCount orders the results by values count.
|
||||||
|
func ByValuesCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return func(s *sql.Selector) {
|
||||||
|
sqlgraph.OrderByNeighborsCount(s, newValuesStep(), opts...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByValues orders the results by values terms.
|
||||||
|
func ByValues(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
||||||
|
return func(s *sql.Selector) {
|
||||||
|
sqlgraph.OrderByNeighborTerms(s, newValuesStep(), append([]sql.OrderTerm{term}, terms...)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func newValuesStep() *sqlgraph.Step {
|
||||||
|
return sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(Table, FieldID),
|
||||||
|
sqlgraph.To(ValuesInverseTable, FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.O2M, false, ValuesTable, ValuesColumn),
|
||||||
|
)
|
||||||
|
}
|
||||||
664
backend/ent/userattributedefinition/where.go
Normal file
664
backend/ent/userattributedefinition/where.go
Normal file
@@ -0,0 +1,664 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package userattributedefinition
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ID filters vertices based on their ID field.
|
||||||
|
func ID(id int64) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDEQ applies the EQ predicate on the ID field.
|
||||||
|
func IDEQ(id int64) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDNEQ applies the NEQ predicate on the ID field.
|
||||||
|
func IDNEQ(id int64) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDIn applies the In predicate on the ID field.
|
||||||
|
func IDIn(ids ...int64) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldID, ids...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDNotIn applies the NotIn predicate on the ID field.
|
||||||
|
func IDNotIn(ids ...int64) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldID, ids...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDGT applies the GT predicate on the ID field.
|
||||||
|
func IDGT(id int64) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDGTE applies the GTE predicate on the ID field.
|
||||||
|
func IDGTE(id int64) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDLT applies the LT predicate on the ID field.
|
||||||
|
func IDLT(id int64) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDLTE applies the LTE predicate on the ID field.
|
||||||
|
func IDLTE(id int64) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||||
|
func CreatedAt(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
|
||||||
|
func UpdatedAt(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
|
||||||
|
func DeletedAt(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDeletedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Key applies equality check predicate on the "key" field. It's identical to KeyEQ.
|
||||||
|
func Key(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
|
||||||
|
func Name(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
|
||||||
|
func Description(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type applies equality check predicate on the "type" field. It's identical to TypeEQ.
|
||||||
|
func Type(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required applies equality check predicate on the "required" field. It's identical to RequiredEQ.
|
||||||
|
func Required(v bool) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldRequired, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Placeholder applies equality check predicate on the "placeholder" field. It's identical to PlaceholderEQ.
|
||||||
|
func Placeholder(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayOrder applies equality check predicate on the "display_order" field. It's identical to DisplayOrderEQ.
|
||||||
|
func DisplayOrder(v int) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDisplayOrder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ.
|
||||||
|
func Enabled(v bool) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
|
func CreatedAtEQ(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
|
||||||
|
func CreatedAtNEQ(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtIn applies the In predicate on the "created_at" field.
|
||||||
|
func CreatedAtIn(vs ...time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldCreatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
|
||||||
|
func CreatedAtNotIn(vs ...time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldCreatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtGT applies the GT predicate on the "created_at" field.
|
||||||
|
func CreatedAtGT(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
|
||||||
|
func CreatedAtGTE(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtLT applies the LT predicate on the "created_at" field.
|
||||||
|
func CreatedAtLT(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
|
||||||
|
func CreatedAtLTE(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtEQ(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtNEQ(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtIn applies the In predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtIn(vs ...time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldUpdatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtNotIn(vs ...time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldUpdatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtGT(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtGTE(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtLT(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtLTE(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtEQ(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDeletedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtNEQ(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldDeletedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtIn applies the In predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtIn(vs ...time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldDeletedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtNotIn(vs ...time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldDeletedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtGT(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldDeletedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtGTE(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldDeletedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtLT(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldDeletedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtLTE(v time.Time) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldDeletedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtIsNil() predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIsNull(FieldDeletedAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
|
||||||
|
func DeletedAtNotNil() predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotNull(FieldDeletedAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyEQ applies the EQ predicate on the "key" field.
|
||||||
|
func KeyEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyNEQ applies the NEQ predicate on the "key" field.
|
||||||
|
func KeyNEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyIn applies the In predicate on the "key" field.
|
||||||
|
func KeyIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldKey, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyNotIn applies the NotIn predicate on the "key" field.
|
||||||
|
func KeyNotIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldKey, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyGT applies the GT predicate on the "key" field.
|
||||||
|
func KeyGT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyGTE applies the GTE predicate on the "key" field.
|
||||||
|
func KeyGTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyLT applies the LT predicate on the "key" field.
|
||||||
|
func KeyLT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyLTE applies the LTE predicate on the "key" field.
|
||||||
|
func KeyLTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyContains applies the Contains predicate on the "key" field.
|
||||||
|
func KeyContains(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContains(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyHasPrefix applies the HasPrefix predicate on the "key" field.
|
||||||
|
func KeyHasPrefix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyHasSuffix applies the HasSuffix predicate on the "key" field.
|
||||||
|
func KeyHasSuffix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyEqualFold applies the EqualFold predicate on the "key" field.
|
||||||
|
func KeyEqualFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyContainsFold applies the ContainsFold predicate on the "key" field.
|
||||||
|
func KeyContainsFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldKey, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameEQ applies the EQ predicate on the "name" field.
|
||||||
|
func NameEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameNEQ applies the NEQ predicate on the "name" field.
|
||||||
|
func NameNEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameIn applies the In predicate on the "name" field.
|
||||||
|
func NameIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldName, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameNotIn applies the NotIn predicate on the "name" field.
|
||||||
|
func NameNotIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldName, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameGT applies the GT predicate on the "name" field.
|
||||||
|
func NameGT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameGTE applies the GTE predicate on the "name" field.
|
||||||
|
func NameGTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameLT applies the LT predicate on the "name" field.
|
||||||
|
func NameLT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameLTE applies the LTE predicate on the "name" field.
|
||||||
|
func NameLTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameContains applies the Contains predicate on the "name" field.
|
||||||
|
func NameContains(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContains(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameHasPrefix applies the HasPrefix predicate on the "name" field.
|
||||||
|
func NameHasPrefix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameHasSuffix applies the HasSuffix predicate on the "name" field.
|
||||||
|
func NameHasSuffix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameEqualFold applies the EqualFold predicate on the "name" field.
|
||||||
|
func NameEqualFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NameContainsFold applies the ContainsFold predicate on the "name" field.
|
||||||
|
func NameContainsFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldName, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionEQ applies the EQ predicate on the "description" field.
|
||||||
|
func DescriptionEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionNEQ applies the NEQ predicate on the "description" field.
|
||||||
|
func DescriptionNEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionIn applies the In predicate on the "description" field.
|
||||||
|
func DescriptionIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldDescription, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionNotIn applies the NotIn predicate on the "description" field.
|
||||||
|
func DescriptionNotIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldDescription, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionGT applies the GT predicate on the "description" field.
|
||||||
|
func DescriptionGT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionGTE applies the GTE predicate on the "description" field.
|
||||||
|
func DescriptionGTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionLT applies the LT predicate on the "description" field.
|
||||||
|
func DescriptionLT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionLTE applies the LTE predicate on the "description" field.
|
||||||
|
func DescriptionLTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionContains applies the Contains predicate on the "description" field.
|
||||||
|
func DescriptionContains(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContains(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
|
||||||
|
func DescriptionHasPrefix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
|
||||||
|
func DescriptionHasSuffix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
|
||||||
|
func DescriptionEqualFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
|
||||||
|
func DescriptionContainsFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldDescription, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeEQ applies the EQ predicate on the "type" field.
|
||||||
|
func TypeEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeNEQ applies the NEQ predicate on the "type" field.
|
||||||
|
func TypeNEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeIn applies the In predicate on the "type" field.
|
||||||
|
func TypeIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldType, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeNotIn applies the NotIn predicate on the "type" field.
|
||||||
|
func TypeNotIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldType, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeGT applies the GT predicate on the "type" field.
|
||||||
|
func TypeGT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeGTE applies the GTE predicate on the "type" field.
|
||||||
|
func TypeGTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeLT applies the LT predicate on the "type" field.
|
||||||
|
func TypeLT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeLTE applies the LTE predicate on the "type" field.
|
||||||
|
func TypeLTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeContains applies the Contains predicate on the "type" field.
|
||||||
|
func TypeContains(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContains(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeHasPrefix applies the HasPrefix predicate on the "type" field.
|
||||||
|
func TypeHasPrefix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeHasSuffix applies the HasSuffix predicate on the "type" field.
|
||||||
|
func TypeHasSuffix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeEqualFold applies the EqualFold predicate on the "type" field.
|
||||||
|
func TypeEqualFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeContainsFold applies the ContainsFold predicate on the "type" field.
|
||||||
|
func TypeContainsFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequiredEQ applies the EQ predicate on the "required" field.
|
||||||
|
func RequiredEQ(v bool) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldRequired, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequiredNEQ applies the NEQ predicate on the "required" field.
|
||||||
|
func RequiredNEQ(v bool) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldRequired, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderEQ applies the EQ predicate on the "placeholder" field.
|
||||||
|
func PlaceholderEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderNEQ applies the NEQ predicate on the "placeholder" field.
|
||||||
|
func PlaceholderNEQ(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderIn applies the In predicate on the "placeholder" field.
|
||||||
|
func PlaceholderIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldPlaceholder, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderNotIn applies the NotIn predicate on the "placeholder" field.
|
||||||
|
func PlaceholderNotIn(vs ...string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldPlaceholder, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderGT applies the GT predicate on the "placeholder" field.
|
||||||
|
func PlaceholderGT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderGTE applies the GTE predicate on the "placeholder" field.
|
||||||
|
func PlaceholderGTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderLT applies the LT predicate on the "placeholder" field.
|
||||||
|
func PlaceholderLT(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderLTE applies the LTE predicate on the "placeholder" field.
|
||||||
|
func PlaceholderLTE(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderContains applies the Contains predicate on the "placeholder" field.
|
||||||
|
func PlaceholderContains(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContains(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderHasPrefix applies the HasPrefix predicate on the "placeholder" field.
|
||||||
|
func PlaceholderHasPrefix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderHasSuffix applies the HasSuffix predicate on the "placeholder" field.
|
||||||
|
func PlaceholderHasSuffix(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderEqualFold applies the EqualFold predicate on the "placeholder" field.
|
||||||
|
func PlaceholderEqualFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderContainsFold applies the ContainsFold predicate on the "placeholder" field.
|
||||||
|
func PlaceholderContainsFold(v string) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldPlaceholder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayOrderEQ applies the EQ predicate on the "display_order" field.
|
||||||
|
func DisplayOrderEQ(v int) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDisplayOrder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayOrderNEQ applies the NEQ predicate on the "display_order" field.
|
||||||
|
func DisplayOrderNEQ(v int) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldDisplayOrder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayOrderIn applies the In predicate on the "display_order" field.
|
||||||
|
func DisplayOrderIn(vs ...int) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldIn(FieldDisplayOrder, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayOrderNotIn applies the NotIn predicate on the "display_order" field.
|
||||||
|
func DisplayOrderNotIn(vs ...int) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldDisplayOrder, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayOrderGT applies the GT predicate on the "display_order" field.
|
||||||
|
func DisplayOrderGT(v int) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGT(FieldDisplayOrder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayOrderGTE applies the GTE predicate on the "display_order" field.
|
||||||
|
func DisplayOrderGTE(v int) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldDisplayOrder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayOrderLT applies the LT predicate on the "display_order" field.
|
||||||
|
func DisplayOrderLT(v int) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLT(FieldDisplayOrder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayOrderLTE applies the LTE predicate on the "display_order" field.
|
||||||
|
func DisplayOrderLTE(v int) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldDisplayOrder, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnabledEQ applies the EQ predicate on the "enabled" field.
|
||||||
|
func EnabledEQ(v bool) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnabledNEQ applies the NEQ predicate on the "enabled" field.
|
||||||
|
func EnabledNEQ(v bool) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasValues applies the HasEdge predicate on the "values" edge.
|
||||||
|
func HasValues() predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(func(s *sql.Selector) {
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(Table, FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.O2M, false, ValuesTable, ValuesColumn),
|
||||||
|
)
|
||||||
|
sqlgraph.HasNeighbors(s, step)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasValuesWith applies the HasEdge predicate on the "values" edge with a given conditions (other predicates).
|
||||||
|
func HasValuesWith(preds ...predicate.UserAttributeValue) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(func(s *sql.Selector) {
|
||||||
|
step := newValuesStep()
|
||||||
|
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||||
|
for _, p := range preds {
|
||||||
|
p(s)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// And groups predicates with the AND operator between them.
|
||||||
|
func And(predicates ...predicate.UserAttributeDefinition) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.AndPredicates(predicates...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Or groups predicates with the OR operator between them.
|
||||||
|
func Or(predicates ...predicate.UserAttributeDefinition) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.OrPredicates(predicates...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not applies the not operator on the given predicate.
|
||||||
|
func Not(p predicate.UserAttributeDefinition) predicate.UserAttributeDefinition {
|
||||||
|
return predicate.UserAttributeDefinition(sql.NotPredicates(p))
|
||||||
|
}
|
||||||
1267
backend/ent/userattributedefinition_create.go
Normal file
1267
backend/ent/userattributedefinition_create.go
Normal file
File diff suppressed because it is too large
Load Diff
88
backend/ent/userattributedefinition_delete.go
Normal file
88
backend/ent/userattributedefinition_delete.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeDefinitionDelete is the builder for deleting a UserAttributeDefinition entity.
|
||||||
|
type UserAttributeDefinitionDelete struct {
|
||||||
|
config
|
||||||
|
hooks []Hook
|
||||||
|
mutation *UserAttributeDefinitionMutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the UserAttributeDefinitionDelete builder.
|
||||||
|
func (_d *UserAttributeDefinitionDelete) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionDelete {
|
||||||
|
_d.mutation.Where(ps...)
|
||||||
|
return _d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||||
|
func (_d *UserAttributeDefinitionDelete) Exec(ctx context.Context) (int, error) {
|
||||||
|
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_d *UserAttributeDefinitionDelete) ExecX(ctx context.Context) int {
|
||||||
|
n, err := _d.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_d *UserAttributeDefinitionDelete) sqlExec(ctx context.Context) (int, error) {
|
||||||
|
_spec := sqlgraph.NewDeleteSpec(userattributedefinition.Table, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
|
||||||
|
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||||
|
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
_d.mutation.done = true
|
||||||
|
return affected, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeDefinitionDeleteOne is the builder for deleting a single UserAttributeDefinition entity.
|
||||||
|
type UserAttributeDefinitionDeleteOne struct {
|
||||||
|
_d *UserAttributeDefinitionDelete
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the UserAttributeDefinitionDelete builder.
|
||||||
|
func (_d *UserAttributeDefinitionDeleteOne) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionDeleteOne {
|
||||||
|
_d._d.mutation.Where(ps...)
|
||||||
|
return _d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the deletion query.
|
||||||
|
func (_d *UserAttributeDefinitionDeleteOne) Exec(ctx context.Context) error {
|
||||||
|
n, err := _d._d.Exec(ctx)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return err
|
||||||
|
case n == 0:
|
||||||
|
return &NotFoundError{userattributedefinition.Label}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_d *UserAttributeDefinitionDeleteOne) ExecX(ctx context.Context) {
|
||||||
|
if err := _d.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
606
backend/ent/userattributedefinition_query.go
Normal file
606
backend/ent/userattributedefinition_query.go
Normal file
@@ -0,0 +1,606 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeDefinitionQuery is the builder for querying UserAttributeDefinition entities.
|
||||||
|
type UserAttributeDefinitionQuery struct {
|
||||||
|
config
|
||||||
|
ctx *QueryContext
|
||||||
|
order []userattributedefinition.OrderOption
|
||||||
|
inters []Interceptor
|
||||||
|
predicates []predicate.UserAttributeDefinition
|
||||||
|
withValues *UserAttributeValueQuery
|
||||||
|
// intermediate query (i.e. traversal path).
|
||||||
|
sql *sql.Selector
|
||||||
|
path func(context.Context) (*sql.Selector, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where adds a new predicate for the UserAttributeDefinitionQuery builder.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionQuery {
|
||||||
|
_q.predicates = append(_q.predicates, ps...)
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit the number of records to be returned by this query.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Limit(limit int) *UserAttributeDefinitionQuery {
|
||||||
|
_q.ctx.Limit = &limit
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Offset to start from.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Offset(offset int) *UserAttributeDefinitionQuery {
|
||||||
|
_q.ctx.Offset = &offset
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unique configures the query builder to filter duplicate records on query.
|
||||||
|
// By default, unique is set to true, and can be disabled using this method.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Unique(unique bool) *UserAttributeDefinitionQuery {
|
||||||
|
_q.ctx.Unique = &unique
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Order specifies how the records should be ordered.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Order(o ...userattributedefinition.OrderOption) *UserAttributeDefinitionQuery {
|
||||||
|
_q.order = append(_q.order, o...)
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryValues chains the current query on the "values" edge.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) QueryValues() *UserAttributeValueQuery {
|
||||||
|
query := (&UserAttributeValueClient{config: _q.config}).Query()
|
||||||
|
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
selector := _q.sqlQuery(ctx)
|
||||||
|
if err := selector.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(userattributedefinition.Table, userattributedefinition.FieldID, selector),
|
||||||
|
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.O2M, false, userattributedefinition.ValuesTable, userattributedefinition.ValuesColumn),
|
||||||
|
)
|
||||||
|
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||||
|
return fromU, nil
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// First returns the first UserAttributeDefinition entity from the query.
|
||||||
|
// Returns a *NotFoundError when no UserAttributeDefinition was found.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) First(ctx context.Context) (*UserAttributeDefinition, error) {
|
||||||
|
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return nil, &NotFoundError{userattributedefinition.Label}
|
||||||
|
}
|
||||||
|
return nodes[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstX is like First, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) FirstX(ctx context.Context) *UserAttributeDefinition {
|
||||||
|
node, err := _q.First(ctx)
|
||||||
|
if err != nil && !IsNotFound(err) {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstID returns the first UserAttributeDefinition ID from the query.
|
||||||
|
// Returns a *NotFoundError when no UserAttributeDefinition ID was found.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||||
|
var ids []int64
|
||||||
|
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
err = &NotFoundError{userattributedefinition.Label}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return ids[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) FirstIDX(ctx context.Context) int64 {
|
||||||
|
id, err := _q.FirstID(ctx)
|
||||||
|
if err != nil && !IsNotFound(err) {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only returns a single UserAttributeDefinition entity found by the query, ensuring it only returns one.
|
||||||
|
// Returns a *NotSingularError when more than one UserAttributeDefinition entity is found.
|
||||||
|
// Returns a *NotFoundError when no UserAttributeDefinition entities are found.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Only(ctx context.Context) (*UserAttributeDefinition, error) {
|
||||||
|
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch len(nodes) {
|
||||||
|
case 1:
|
||||||
|
return nodes[0], nil
|
||||||
|
case 0:
|
||||||
|
return nil, &NotFoundError{userattributedefinition.Label}
|
||||||
|
default:
|
||||||
|
return nil, &NotSingularError{userattributedefinition.Label}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyX is like Only, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) OnlyX(ctx context.Context) *UserAttributeDefinition {
|
||||||
|
node, err := _q.Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyID is like Only, but returns the only UserAttributeDefinition ID in the query.
|
||||||
|
// Returns a *NotSingularError when more than one UserAttributeDefinition ID is found.
|
||||||
|
// Returns a *NotFoundError when no entities are found.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||||
|
var ids []int64
|
||||||
|
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch len(ids) {
|
||||||
|
case 1:
|
||||||
|
id = ids[0]
|
||||||
|
case 0:
|
||||||
|
err = &NotFoundError{userattributedefinition.Label}
|
||||||
|
default:
|
||||||
|
err = &NotSingularError{userattributedefinition.Label}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) OnlyIDX(ctx context.Context) int64 {
|
||||||
|
id, err := _q.OnlyID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// All executes the query and returns a list of UserAttributeDefinitions.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) All(ctx context.Context) ([]*UserAttributeDefinition, error) {
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
qr := querierAll[[]*UserAttributeDefinition, *UserAttributeDefinitionQuery]()
|
||||||
|
return withInterceptors[[]*UserAttributeDefinition](ctx, _q, qr, _q.inters)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllX is like All, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) AllX(ctx context.Context) []*UserAttributeDefinition {
|
||||||
|
nodes, err := _q.All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDs executes the query and returns a list of UserAttributeDefinition IDs.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||||
|
if _q.ctx.Unique == nil && _q.path != nil {
|
||||||
|
_q.Unique(true)
|
||||||
|
}
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||||
|
if err = _q.Select(userattributedefinition.FieldID).Scan(ctx, &ids); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDsX is like IDs, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) IDsX(ctx context.Context) []int64 {
|
||||||
|
ids, err := _q.IDs(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the count of the given query.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Count(ctx context.Context) (int, error) {
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return withInterceptors[int](ctx, _q, querierCount[*UserAttributeDefinitionQuery](), _q.inters)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountX is like Count, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) CountX(ctx context.Context) int {
|
||||||
|
count, err := _q.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exist returns true if the query has elements in the graph.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Exist(ctx context.Context) (bool, error) {
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||||
|
switch _, err := _q.FirstID(ctx); {
|
||||||
|
case IsNotFound(err):
|
||||||
|
return false, nil
|
||||||
|
case err != nil:
|
||||||
|
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||||
|
default:
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExistX is like Exist, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) ExistX(ctx context.Context) bool {
|
||||||
|
exist, err := _q.Exist(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return exist
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone returns a duplicate of the UserAttributeDefinitionQuery builder, including all associated steps. It can be
|
||||||
|
// used to prepare common query builders and use them differently after the clone is made.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Clone() *UserAttributeDefinitionQuery {
|
||||||
|
if _q == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UserAttributeDefinitionQuery{
|
||||||
|
config: _q.config,
|
||||||
|
ctx: _q.ctx.Clone(),
|
||||||
|
order: append([]userattributedefinition.OrderOption{}, _q.order...),
|
||||||
|
inters: append([]Interceptor{}, _q.inters...),
|
||||||
|
predicates: append([]predicate.UserAttributeDefinition{}, _q.predicates...),
|
||||||
|
withValues: _q.withValues.Clone(),
|
||||||
|
// clone intermediate query.
|
||||||
|
sql: _q.sql.Clone(),
|
||||||
|
path: _q.path,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithValues tells the query-builder to eager-load the nodes that are connected to
|
||||||
|
// the "values" edge. The optional arguments are used to configure the query builder of the edge.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) WithValues(opts ...func(*UserAttributeValueQuery)) *UserAttributeDefinitionQuery {
|
||||||
|
query := (&UserAttributeValueClient{config: _q.config}).Query()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(query)
|
||||||
|
}
|
||||||
|
_q.withValues = query
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupBy is used to group vertices by one or more fields/columns.
|
||||||
|
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// var v []struct {
|
||||||
|
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
|
// Count int `json:"count,omitempty"`
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// client.UserAttributeDefinition.Query().
|
||||||
|
// GroupBy(userattributedefinition.FieldCreatedAt).
|
||||||
|
// Aggregate(ent.Count()).
|
||||||
|
// Scan(ctx, &v)
|
||||||
|
func (_q *UserAttributeDefinitionQuery) GroupBy(field string, fields ...string) *UserAttributeDefinitionGroupBy {
|
||||||
|
_q.ctx.Fields = append([]string{field}, fields...)
|
||||||
|
grbuild := &UserAttributeDefinitionGroupBy{build: _q}
|
||||||
|
grbuild.flds = &_q.ctx.Fields
|
||||||
|
grbuild.label = userattributedefinition.Label
|
||||||
|
grbuild.scan = grbuild.Scan
|
||||||
|
return grbuild
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select allows the selection one or more fields/columns for the given query,
|
||||||
|
// instead of selecting all fields in the entity.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// var v []struct {
|
||||||
|
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// client.UserAttributeDefinition.Query().
|
||||||
|
// Select(userattributedefinition.FieldCreatedAt).
|
||||||
|
// Scan(ctx, &v)
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Select(fields ...string) *UserAttributeDefinitionSelect {
|
||||||
|
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||||
|
sbuild := &UserAttributeDefinitionSelect{UserAttributeDefinitionQuery: _q}
|
||||||
|
sbuild.label = userattributedefinition.Label
|
||||||
|
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||||
|
return sbuild
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate returns a UserAttributeDefinitionSelect configured with the given aggregations.
|
||||||
|
func (_q *UserAttributeDefinitionQuery) Aggregate(fns ...AggregateFunc) *UserAttributeDefinitionSelect {
|
||||||
|
return _q.Select().Aggregate(fns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeDefinitionQuery) prepareQuery(ctx context.Context) error {
|
||||||
|
for _, inter := range _q.inters {
|
||||||
|
if inter == nil {
|
||||||
|
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||||
|
}
|
||||||
|
if trv, ok := inter.(Traverser); ok {
|
||||||
|
if err := trv.Traverse(ctx, _q); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, f := range _q.ctx.Fields {
|
||||||
|
if !userattributedefinition.ValidColumn(f) {
|
||||||
|
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _q.path != nil {
|
||||||
|
prev, err := _q.path(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_q.sql = prev
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeDefinitionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserAttributeDefinition, error) {
|
||||||
|
var (
|
||||||
|
nodes = []*UserAttributeDefinition{}
|
||||||
|
_spec = _q.querySpec()
|
||||||
|
loadedTypes = [1]bool{
|
||||||
|
_q.withValues != nil,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||||
|
return (*UserAttributeDefinition).scanValues(nil, columns)
|
||||||
|
}
|
||||||
|
_spec.Assign = func(columns []string, values []any) error {
|
||||||
|
node := &UserAttributeDefinition{config: _q.config}
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
node.Edges.loadedTypes = loadedTypes
|
||||||
|
return node.assignValues(columns, values)
|
||||||
|
}
|
||||||
|
for i := range hooks {
|
||||||
|
hooks[i](ctx, _spec)
|
||||||
|
}
|
||||||
|
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
if query := _q.withValues; query != nil {
|
||||||
|
if err := _q.loadValues(ctx, query, nodes,
|
||||||
|
func(n *UserAttributeDefinition) { n.Edges.Values = []*UserAttributeValue{} },
|
||||||
|
func(n *UserAttributeDefinition, e *UserAttributeValue) { n.Edges.Values = append(n.Edges.Values, e) }); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeDefinitionQuery) loadValues(ctx context.Context, query *UserAttributeValueQuery, nodes []*UserAttributeDefinition, init func(*UserAttributeDefinition), assign func(*UserAttributeDefinition, *UserAttributeValue)) error {
|
||||||
|
fks := make([]driver.Value, 0, len(nodes))
|
||||||
|
nodeids := make(map[int64]*UserAttributeDefinition)
|
||||||
|
for i := range nodes {
|
||||||
|
fks = append(fks, nodes[i].ID)
|
||||||
|
nodeids[nodes[i].ID] = nodes[i]
|
||||||
|
if init != nil {
|
||||||
|
init(nodes[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(query.ctx.Fields) > 0 {
|
||||||
|
query.ctx.AppendFieldOnce(userattributevalue.FieldAttributeID)
|
||||||
|
}
|
||||||
|
query.Where(predicate.UserAttributeValue(func(s *sql.Selector) {
|
||||||
|
s.Where(sql.InValues(s.C(userattributedefinition.ValuesColumn), fks...))
|
||||||
|
}))
|
||||||
|
neighbors, err := query.All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, n := range neighbors {
|
||||||
|
fk := n.AttributeID
|
||||||
|
node, ok := nodeids[fk]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf(`unexpected referenced foreign-key "attribute_id" returned %v for node %v`, fk, n.ID)
|
||||||
|
}
|
||||||
|
assign(node, n)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeDefinitionQuery) sqlCount(ctx context.Context) (int, error) {
|
||||||
|
_spec := _q.querySpec()
|
||||||
|
_spec.Node.Columns = _q.ctx.Fields
|
||||||
|
if len(_q.ctx.Fields) > 0 {
|
||||||
|
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||||
|
}
|
||||||
|
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeDefinitionQuery) querySpec() *sqlgraph.QuerySpec {
|
||||||
|
_spec := sqlgraph.NewQuerySpec(userattributedefinition.Table, userattributedefinition.Columns, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
|
||||||
|
_spec.From = _q.sql
|
||||||
|
if unique := _q.ctx.Unique; unique != nil {
|
||||||
|
_spec.Unique = *unique
|
||||||
|
} else if _q.path != nil {
|
||||||
|
_spec.Unique = true
|
||||||
|
}
|
||||||
|
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||||
|
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, userattributedefinition.FieldID)
|
||||||
|
for i := range fields {
|
||||||
|
if fields[i] != userattributedefinition.FieldID {
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ps := _q.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if limit := _q.ctx.Limit; limit != nil {
|
||||||
|
_spec.Limit = *limit
|
||||||
|
}
|
||||||
|
if offset := _q.ctx.Offset; offset != nil {
|
||||||
|
_spec.Offset = *offset
|
||||||
|
}
|
||||||
|
if ps := _q.order; len(ps) > 0 {
|
||||||
|
_spec.Order = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return _spec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||||
|
builder := sql.Dialect(_q.driver.Dialect())
|
||||||
|
t1 := builder.Table(userattributedefinition.Table)
|
||||||
|
columns := _q.ctx.Fields
|
||||||
|
if len(columns) == 0 {
|
||||||
|
columns = userattributedefinition.Columns
|
||||||
|
}
|
||||||
|
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||||
|
if _q.sql != nil {
|
||||||
|
selector = _q.sql
|
||||||
|
selector.Select(selector.Columns(columns...)...)
|
||||||
|
}
|
||||||
|
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||||
|
selector.Distinct()
|
||||||
|
}
|
||||||
|
for _, p := range _q.predicates {
|
||||||
|
p(selector)
|
||||||
|
}
|
||||||
|
for _, p := range _q.order {
|
||||||
|
p(selector)
|
||||||
|
}
|
||||||
|
if offset := _q.ctx.Offset; offset != nil {
|
||||||
|
// limit is mandatory for offset clause. We start
|
||||||
|
// with default value, and override it below if needed.
|
||||||
|
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||||
|
}
|
||||||
|
if limit := _q.ctx.Limit; limit != nil {
|
||||||
|
selector.Limit(*limit)
|
||||||
|
}
|
||||||
|
return selector
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeDefinitionGroupBy is the group-by builder for UserAttributeDefinition entities.
|
||||||
|
type UserAttributeDefinitionGroupBy struct {
|
||||||
|
selector
|
||||||
|
build *UserAttributeDefinitionQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate adds the given aggregation functions to the group-by query.
|
||||||
|
func (_g *UserAttributeDefinitionGroupBy) Aggregate(fns ...AggregateFunc) *UserAttributeDefinitionGroupBy {
|
||||||
|
_g.fns = append(_g.fns, fns...)
|
||||||
|
return _g
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan applies the selector query and scans the result into the given value.
|
||||||
|
func (_g *UserAttributeDefinitionGroupBy) Scan(ctx context.Context, v any) error {
|
||||||
|
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||||
|
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanWithInterceptors[*UserAttributeDefinitionQuery, *UserAttributeDefinitionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_g *UserAttributeDefinitionGroupBy) sqlScan(ctx context.Context, root *UserAttributeDefinitionQuery, v any) error {
|
||||||
|
selector := root.sqlQuery(ctx).Select()
|
||||||
|
aggregation := make([]string, 0, len(_g.fns))
|
||||||
|
for _, fn := range _g.fns {
|
||||||
|
aggregation = append(aggregation, fn(selector))
|
||||||
|
}
|
||||||
|
if len(selector.SelectedColumns()) == 0 {
|
||||||
|
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||||
|
for _, f := range *_g.flds {
|
||||||
|
columns = append(columns, selector.C(f))
|
||||||
|
}
|
||||||
|
columns = append(columns, aggregation...)
|
||||||
|
selector.Select(columns...)
|
||||||
|
}
|
||||||
|
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||||
|
if err := selector.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rows := &sql.Rows{}
|
||||||
|
query, args := selector.Query()
|
||||||
|
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return sql.ScanSlice(rows, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeDefinitionSelect is the builder for selecting fields of UserAttributeDefinition entities.
|
||||||
|
type UserAttributeDefinitionSelect struct {
|
||||||
|
*UserAttributeDefinitionQuery
|
||||||
|
selector
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate adds the given aggregation functions to the selector query.
|
||||||
|
func (_s *UserAttributeDefinitionSelect) Aggregate(fns ...AggregateFunc) *UserAttributeDefinitionSelect {
|
||||||
|
_s.fns = append(_s.fns, fns...)
|
||||||
|
return _s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan applies the selector query and scans the result into the given value.
|
||||||
|
func (_s *UserAttributeDefinitionSelect) Scan(ctx context.Context, v any) error {
|
||||||
|
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||||
|
if err := _s.prepareQuery(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanWithInterceptors[*UserAttributeDefinitionQuery, *UserAttributeDefinitionSelect](ctx, _s.UserAttributeDefinitionQuery, _s, _s.inters, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_s *UserAttributeDefinitionSelect) sqlScan(ctx context.Context, root *UserAttributeDefinitionQuery, v any) error {
|
||||||
|
selector := root.sqlQuery(ctx)
|
||||||
|
aggregation := make([]string, 0, len(_s.fns))
|
||||||
|
for _, fn := range _s.fns {
|
||||||
|
aggregation = append(aggregation, fn(selector))
|
||||||
|
}
|
||||||
|
switch n := len(*_s.selector.flds); {
|
||||||
|
case n == 0 && len(aggregation) > 0:
|
||||||
|
selector.Select(aggregation...)
|
||||||
|
case n != 0 && len(aggregation) > 0:
|
||||||
|
selector.AppendSelect(aggregation...)
|
||||||
|
}
|
||||||
|
rows := &sql.Rows{}
|
||||||
|
query, args := selector.Query()
|
||||||
|
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return sql.ScanSlice(rows, v)
|
||||||
|
}
|
||||||
846
backend/ent/userattributedefinition_update.go
Normal file
846
backend/ent/userattributedefinition_update.go
Normal file
@@ -0,0 +1,846 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/dialect/sql/sqljson"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeDefinitionUpdate is the builder for updating UserAttributeDefinition entities.
|
||||||
|
type UserAttributeDefinitionUpdate struct {
|
||||||
|
config
|
||||||
|
hooks []Hook
|
||||||
|
mutation *UserAttributeDefinitionMutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the UserAttributeDefinitionUpdate builder.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.Where(ps...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetUpdatedAt(v time.Time) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeletedAt sets the "deleted_at" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetDeletedAt(v time.Time) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetDeletedAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetNillableDeletedAt(v *time.Time) *UserAttributeDefinitionUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetDeletedAt(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearDeletedAt clears the value of the "deleted_at" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) ClearDeletedAt() *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.ClearDeletedAt()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetKey sets the "key" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetKey(v string) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetKey(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableKey sets the "key" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetNillableKey(v *string) *UserAttributeDefinitionUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetKey(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetName sets the "name" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetName(v string) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetName(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableName sets the "name" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetNillableName(v *string) *UserAttributeDefinitionUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetName(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDescription sets the "description" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetDescription(v string) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetDescription(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableDescription sets the "description" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetNillableDescription(v *string) *UserAttributeDefinitionUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetDescription(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetType sets the "type" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetType(v string) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetType(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableType sets the "type" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetNillableType(v *string) *UserAttributeDefinitionUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetType(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOptions sets the "options" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetOptions(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendOptions appends value to the "options" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) AppendOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.AppendOptions(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRequired sets the "required" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetRequired(v bool) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetRequired(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRequired sets the "required" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetNillableRequired(v *bool) *UserAttributeDefinitionUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRequired(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetValidation sets the "validation" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetValidation(v map[string]interface{}) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetValidation(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPlaceholder sets the "placeholder" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetPlaceholder(v string) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetPlaceholder(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillablePlaceholder sets the "placeholder" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetNillablePlaceholder(v *string) *UserAttributeDefinitionUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetPlaceholder(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisplayOrder sets the "display_order" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetDisplayOrder(v int) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.ResetDisplayOrder()
|
||||||
|
_u.mutation.SetDisplayOrder(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableDisplayOrder sets the "display_order" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetNillableDisplayOrder(v *int) *UserAttributeDefinitionUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetDisplayOrder(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDisplayOrder adds value to the "display_order" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) AddDisplayOrder(v int) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.AddDisplayOrder(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnabled sets the "enabled" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetEnabled(v bool) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.SetEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SetNillableEnabled(v *bool) *UserAttributeDefinitionUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddValueIDs adds the "values" edge to the UserAttributeValue entity by IDs.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) AddValueIDs(ids ...int64) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.AddValueIDs(ids...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddValues adds the "values" edges to the UserAttributeValue entity.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) AddValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdate {
|
||||||
|
ids := make([]int64, len(v))
|
||||||
|
for i := range v {
|
||||||
|
ids[i] = v[i].ID
|
||||||
|
}
|
||||||
|
return _u.AddValueIDs(ids...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutation returns the UserAttributeDefinitionMutation object of the builder.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) Mutation() *UserAttributeDefinitionMutation {
|
||||||
|
return _u.mutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearValues clears all "values" edges to the UserAttributeValue entity.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) ClearValues() *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.ClearValues()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveValueIDs removes the "values" edge to UserAttributeValue entities by IDs.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) RemoveValueIDs(ids ...int64) *UserAttributeDefinitionUpdate {
|
||||||
|
_u.mutation.RemoveValueIDs(ids...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveValues removes "values" edges to UserAttributeValue entities.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) RemoveValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdate {
|
||||||
|
ids := make([]int64, len(v))
|
||||||
|
for i := range v {
|
||||||
|
ids[i] = v[i].ID
|
||||||
|
}
|
||||||
|
return _u.RemoveValueIDs(ids...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) Save(ctx context.Context) (int, error) {
|
||||||
|
if err := _u.defaults(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveX is like Save, but panics if an error occurs.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) SaveX(ctx context.Context) int {
|
||||||
|
affected, err := _u.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return affected
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) Exec(ctx context.Context) error {
|
||||||
|
_, err := _u.Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) ExecX(ctx context.Context) {
|
||||||
|
if err := _u.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaults sets the default values of the builder before save.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) defaults() error {
|
||||||
|
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||||
|
if userattributedefinition.UpdateDefaultUpdatedAt == nil {
|
||||||
|
return fmt.Errorf("ent: uninitialized userattributedefinition.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
|
||||||
|
}
|
||||||
|
v := userattributedefinition.UpdateDefaultUpdatedAt()
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check runs all checks and user-defined validators on the builder.
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) check() error {
|
||||||
|
if v, ok := _u.mutation.Key(); ok {
|
||||||
|
if err := userattributedefinition.KeyValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.key": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.Name(); ok {
|
||||||
|
if err := userattributedefinition.NameValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.name": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.GetType(); ok {
|
||||||
|
if err := userattributedefinition.TypeValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "type", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.type": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.Placeholder(); ok {
|
||||||
|
if err := userattributedefinition.PlaceholderValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "placeholder", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.placeholder": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_u *UserAttributeDefinitionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||||
|
if err := _u.check(); err != nil {
|
||||||
|
return _node, err
|
||||||
|
}
|
||||||
|
_spec := sqlgraph.NewUpdateSpec(userattributedefinition.Table, userattributedefinition.Columns, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
|
||||||
|
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldUpdatedAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.DeletedAt(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldDeletedAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.DeletedAtCleared() {
|
||||||
|
_spec.ClearField(userattributedefinition.FieldDeletedAt, field.TypeTime)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Key(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldKey, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Name(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldName, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Description(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldDescription, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.GetType(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldType, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Options(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldOptions, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AppendedOptions(); ok {
|
||||||
|
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||||
|
sqljson.Append(u, userattributedefinition.FieldOptions, value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Required(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldRequired, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Validation(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldValidation, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Placeholder(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldPlaceholder, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.DisplayOrder(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedDisplayOrder(); ok {
|
||||||
|
_spec.AddField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Enabled(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ValuesCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: userattributedefinition.ValuesTable,
|
||||||
|
Columns: []string{userattributedefinition.ValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.RemovedValuesIDs(); len(nodes) > 0 && !_u.mutation.ValuesCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: userattributedefinition.ValuesTable,
|
||||||
|
Columns: []string{userattributedefinition.ValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.ValuesIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: userattributedefinition.ValuesTable,
|
||||||
|
Columns: []string{userattributedefinition.ValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
|
}
|
||||||
|
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||||
|
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||||
|
err = &NotFoundError{userattributedefinition.Label}
|
||||||
|
} else if sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_u.mutation.done = true
|
||||||
|
return _node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeDefinitionUpdateOne is the builder for updating a single UserAttributeDefinition entity.
|
||||||
|
type UserAttributeDefinitionUpdateOne struct {
|
||||||
|
config
|
||||||
|
fields []string
|
||||||
|
hooks []Hook
|
||||||
|
mutation *UserAttributeDefinitionMutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetUpdatedAt(v time.Time) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeletedAt sets the "deleted_at" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetDeletedAt(v time.Time) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetDeletedAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetNillableDeletedAt(v *time.Time) *UserAttributeDefinitionUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetDeletedAt(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearDeletedAt clears the value of the "deleted_at" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) ClearDeletedAt() *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.ClearDeletedAt()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetKey sets the "key" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetKey(v string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetKey(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableKey sets the "key" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetNillableKey(v *string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetKey(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetName sets the "name" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetName(v string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetName(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableName sets the "name" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetNillableName(v *string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetName(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDescription sets the "description" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetDescription(v string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetDescription(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableDescription sets the "description" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetNillableDescription(v *string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetDescription(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetType sets the "type" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetType(v string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetType(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableType sets the "type" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetNillableType(v *string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetType(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOptions sets the "options" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetOptions(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendOptions appends value to the "options" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) AppendOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.AppendOptions(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRequired sets the "required" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetRequired(v bool) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetRequired(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRequired sets the "required" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetNillableRequired(v *bool) *UserAttributeDefinitionUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRequired(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetValidation sets the "validation" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetValidation(v map[string]interface{}) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetValidation(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPlaceholder sets the "placeholder" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetPlaceholder(v string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetPlaceholder(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillablePlaceholder sets the "placeholder" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetNillablePlaceholder(v *string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetPlaceholder(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisplayOrder sets the "display_order" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetDisplayOrder(v int) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.ResetDisplayOrder()
|
||||||
|
_u.mutation.SetDisplayOrder(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableDisplayOrder sets the "display_order" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetNillableDisplayOrder(v *int) *UserAttributeDefinitionUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetDisplayOrder(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDisplayOrder adds value to the "display_order" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) AddDisplayOrder(v int) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.AddDisplayOrder(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnabled sets the "enabled" field.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetEnabled(v bool) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.SetEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SetNillableEnabled(v *bool) *UserAttributeDefinitionUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddValueIDs adds the "values" edge to the UserAttributeValue entity by IDs.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) AddValueIDs(ids ...int64) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.AddValueIDs(ids...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddValues adds the "values" edges to the UserAttributeValue entity.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) AddValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdateOne {
|
||||||
|
ids := make([]int64, len(v))
|
||||||
|
for i := range v {
|
||||||
|
ids[i] = v[i].ID
|
||||||
|
}
|
||||||
|
return _u.AddValueIDs(ids...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutation returns the UserAttributeDefinitionMutation object of the builder.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) Mutation() *UserAttributeDefinitionMutation {
|
||||||
|
return _u.mutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearValues clears all "values" edges to the UserAttributeValue entity.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) ClearValues() *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.ClearValues()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveValueIDs removes the "values" edge to UserAttributeValue entities by IDs.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) RemoveValueIDs(ids ...int64) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.RemoveValueIDs(ids...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveValues removes "values" edges to UserAttributeValue entities.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) RemoveValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdateOne {
|
||||||
|
ids := make([]int64, len(v))
|
||||||
|
for i := range v {
|
||||||
|
ids[i] = v[i].ID
|
||||||
|
}
|
||||||
|
return _u.RemoveValueIDs(ids...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the UserAttributeDefinitionUpdate builder.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.mutation.Where(ps...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||||
|
// The default is selecting all fields defined in the entity schema.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) Select(field string, fields ...string) *UserAttributeDefinitionUpdateOne {
|
||||||
|
_u.fields = append([]string{field}, fields...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save executes the query and returns the updated UserAttributeDefinition entity.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) Save(ctx context.Context) (*UserAttributeDefinition, error) {
|
||||||
|
if err := _u.defaults(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveX is like Save, but panics if an error occurs.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) SaveX(ctx context.Context) *UserAttributeDefinition {
|
||||||
|
node, err := _u.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query on the entity.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) Exec(ctx context.Context) error {
|
||||||
|
_, err := _u.Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) ExecX(ctx context.Context) {
|
||||||
|
if err := _u.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaults sets the default values of the builder before save.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) defaults() error {
|
||||||
|
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||||
|
if userattributedefinition.UpdateDefaultUpdatedAt == nil {
|
||||||
|
return fmt.Errorf("ent: uninitialized userattributedefinition.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
|
||||||
|
}
|
||||||
|
v := userattributedefinition.UpdateDefaultUpdatedAt()
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check runs all checks and user-defined validators on the builder.
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) check() error {
|
||||||
|
if v, ok := _u.mutation.Key(); ok {
|
||||||
|
if err := userattributedefinition.KeyValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.key": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.Name(); ok {
|
||||||
|
if err := userattributedefinition.NameValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.name": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.GetType(); ok {
|
||||||
|
if err := userattributedefinition.TypeValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "type", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.type": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.Placeholder(); ok {
|
||||||
|
if err := userattributedefinition.PlaceholderValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "placeholder", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.placeholder": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_u *UserAttributeDefinitionUpdateOne) sqlSave(ctx context.Context) (_node *UserAttributeDefinition, err error) {
|
||||||
|
if err := _u.check(); err != nil {
|
||||||
|
return _node, err
|
||||||
|
}
|
||||||
|
_spec := sqlgraph.NewUpdateSpec(userattributedefinition.Table, userattributedefinition.Columns, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
|
||||||
|
id, ok := _u.mutation.ID()
|
||||||
|
if !ok {
|
||||||
|
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserAttributeDefinition.id" for update`)}
|
||||||
|
}
|
||||||
|
_spec.Node.ID.Value = id
|
||||||
|
if fields := _u.fields; len(fields) > 0 {
|
||||||
|
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, userattributedefinition.FieldID)
|
||||||
|
for _, f := range fields {
|
||||||
|
if !userattributedefinition.ValidColumn(f) {
|
||||||
|
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||||
|
}
|
||||||
|
if f != userattributedefinition.FieldID {
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldUpdatedAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.DeletedAt(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldDeletedAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.DeletedAtCleared() {
|
||||||
|
_spec.ClearField(userattributedefinition.FieldDeletedAt, field.TypeTime)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Key(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldKey, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Name(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldName, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Description(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldDescription, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.GetType(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldType, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Options(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldOptions, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AppendedOptions(); ok {
|
||||||
|
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||||
|
sqljson.Append(u, userattributedefinition.FieldOptions, value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Required(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldRequired, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Validation(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldValidation, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Placeholder(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldPlaceholder, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.DisplayOrder(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedDisplayOrder(); ok {
|
||||||
|
_spec.AddField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Enabled(); ok {
|
||||||
|
_spec.SetField(userattributedefinition.FieldEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ValuesCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: userattributedefinition.ValuesTable,
|
||||||
|
Columns: []string{userattributedefinition.ValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.RemovedValuesIDs(); len(nodes) > 0 && !_u.mutation.ValuesCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: userattributedefinition.ValuesTable,
|
||||||
|
Columns: []string{userattributedefinition.ValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.ValuesIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.O2M,
|
||||||
|
Inverse: false,
|
||||||
|
Table: userattributedefinition.ValuesTable,
|
||||||
|
Columns: []string{userattributedefinition.ValuesColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
|
}
|
||||||
|
_node = &UserAttributeDefinition{config: _u.config}
|
||||||
|
_spec.Assign = _node.assignValues
|
||||||
|
_spec.ScanValues = _node.scanValues
|
||||||
|
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
|
||||||
|
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||||
|
err = &NotFoundError{userattributedefinition.Label}
|
||||||
|
} else if sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_u.mutation.done = true
|
||||||
|
return _node, nil
|
||||||
|
}
|
||||||
198
backend/ent/userattributevalue.go
Normal file
198
backend/ent/userattributevalue.go
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeValue is the model entity for the UserAttributeValue schema.
|
||||||
|
type UserAttributeValue struct {
|
||||||
|
config `json:"-"`
|
||||||
|
// ID of the ent.
|
||||||
|
ID int64 `json:"id,omitempty"`
|
||||||
|
// CreatedAt holds the value of the "created_at" field.
|
||||||
|
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
|
// UpdatedAt holds the value of the "updated_at" field.
|
||||||
|
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||||
|
// UserID holds the value of the "user_id" field.
|
||||||
|
UserID int64 `json:"user_id,omitempty"`
|
||||||
|
// AttributeID holds the value of the "attribute_id" field.
|
||||||
|
AttributeID int64 `json:"attribute_id,omitempty"`
|
||||||
|
// Value holds the value of the "value" field.
|
||||||
|
Value string `json:"value,omitempty"`
|
||||||
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
|
// The values are being populated by the UserAttributeValueQuery when eager-loading is set.
|
||||||
|
Edges UserAttributeValueEdges `json:"edges"`
|
||||||
|
selectValues sql.SelectValues
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueEdges holds the relations/edges for other nodes in the graph.
|
||||||
|
type UserAttributeValueEdges struct {
|
||||||
|
// User holds the value of the user edge.
|
||||||
|
User *User `json:"user,omitempty"`
|
||||||
|
// Definition holds the value of the definition edge.
|
||||||
|
Definition *UserAttributeDefinition `json:"definition,omitempty"`
|
||||||
|
// loadedTypes holds the information for reporting if a
|
||||||
|
// type was loaded (or requested) in eager-loading or not.
|
||||||
|
loadedTypes [2]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserOrErr returns the User value or an error if the edge
|
||||||
|
// was not loaded in eager-loading, or loaded but was not found.
|
||||||
|
func (e UserAttributeValueEdges) UserOrErr() (*User, error) {
|
||||||
|
if e.User != nil {
|
||||||
|
return e.User, nil
|
||||||
|
} else if e.loadedTypes[0] {
|
||||||
|
return nil, &NotFoundError{label: user.Label}
|
||||||
|
}
|
||||||
|
return nil, &NotLoadedError{edge: "user"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefinitionOrErr returns the Definition value or an error if the edge
|
||||||
|
// was not loaded in eager-loading, or loaded but was not found.
|
||||||
|
func (e UserAttributeValueEdges) DefinitionOrErr() (*UserAttributeDefinition, error) {
|
||||||
|
if e.Definition != nil {
|
||||||
|
return e.Definition, nil
|
||||||
|
} else if e.loadedTypes[1] {
|
||||||
|
return nil, &NotFoundError{label: userattributedefinition.Label}
|
||||||
|
}
|
||||||
|
return nil, &NotLoadedError{edge: "definition"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanValues returns the types for scanning values from sql.Rows.
|
||||||
|
func (*UserAttributeValue) scanValues(columns []string) ([]any, error) {
|
||||||
|
values := make([]any, len(columns))
|
||||||
|
for i := range columns {
|
||||||
|
switch columns[i] {
|
||||||
|
case userattributevalue.FieldID, userattributevalue.FieldUserID, userattributevalue.FieldAttributeID:
|
||||||
|
values[i] = new(sql.NullInt64)
|
||||||
|
case userattributevalue.FieldValue:
|
||||||
|
values[i] = new(sql.NullString)
|
||||||
|
case userattributevalue.FieldCreatedAt, userattributevalue.FieldUpdatedAt:
|
||||||
|
values[i] = new(sql.NullTime)
|
||||||
|
default:
|
||||||
|
values[i] = new(sql.UnknownType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||||
|
// to the UserAttributeValue fields.
|
||||||
|
func (_m *UserAttributeValue) assignValues(columns []string, values []any) error {
|
||||||
|
if m, n := len(values), len(columns); m < n {
|
||||||
|
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
|
||||||
|
}
|
||||||
|
for i := range columns {
|
||||||
|
switch columns[i] {
|
||||||
|
case userattributevalue.FieldID:
|
||||||
|
value, ok := values[i].(*sql.NullInt64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field id", value)
|
||||||
|
}
|
||||||
|
_m.ID = int64(value.Int64)
|
||||||
|
case userattributevalue.FieldCreatedAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.CreatedAt = value.Time
|
||||||
|
}
|
||||||
|
case userattributevalue.FieldUpdatedAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.UpdatedAt = value.Time
|
||||||
|
}
|
||||||
|
case userattributevalue.FieldUserID:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field user_id", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.UserID = value.Int64
|
||||||
|
}
|
||||||
|
case userattributevalue.FieldAttributeID:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field attribute_id", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.AttributeID = value.Int64
|
||||||
|
}
|
||||||
|
case userattributevalue.FieldValue:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field value", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Value = value.String
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetValue returns the ent.Value that was dynamically selected and assigned to the UserAttributeValue.
|
||||||
|
// This includes values selected through modifiers, order, etc.
|
||||||
|
func (_m *UserAttributeValue) GetValue(name string) (ent.Value, error) {
|
||||||
|
return _m.selectValues.Get(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryUser queries the "user" edge of the UserAttributeValue entity.
|
||||||
|
func (_m *UserAttributeValue) QueryUser() *UserQuery {
|
||||||
|
return NewUserAttributeValueClient(_m.config).QueryUser(_m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryDefinition queries the "definition" edge of the UserAttributeValue entity.
|
||||||
|
func (_m *UserAttributeValue) QueryDefinition() *UserAttributeDefinitionQuery {
|
||||||
|
return NewUserAttributeValueClient(_m.config).QueryDefinition(_m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update returns a builder for updating this UserAttributeValue.
|
||||||
|
// Note that you need to call UserAttributeValue.Unwrap() before calling this method if this UserAttributeValue
|
||||||
|
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||||
|
func (_m *UserAttributeValue) Update() *UserAttributeValueUpdateOne {
|
||||||
|
return NewUserAttributeValueClient(_m.config).UpdateOne(_m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap unwraps the UserAttributeValue entity that was returned from a transaction after it was closed,
|
||||||
|
// so that all future queries will be executed through the driver which created the transaction.
|
||||||
|
func (_m *UserAttributeValue) Unwrap() *UserAttributeValue {
|
||||||
|
_tx, ok := _m.config.driver.(*txDriver)
|
||||||
|
if !ok {
|
||||||
|
panic("ent: UserAttributeValue is not a transactional entity")
|
||||||
|
}
|
||||||
|
_m.config.driver = _tx.drv
|
||||||
|
return _m
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements the fmt.Stringer.
|
||||||
|
func (_m *UserAttributeValue) String() string {
|
||||||
|
var builder strings.Builder
|
||||||
|
builder.WriteString("UserAttributeValue(")
|
||||||
|
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
|
||||||
|
builder.WriteString("created_at=")
|
||||||
|
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("updated_at=")
|
||||||
|
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("user_id=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("attribute_id=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.AttributeID))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("value=")
|
||||||
|
builder.WriteString(_m.Value)
|
||||||
|
builder.WriteByte(')')
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValues is a parsable slice of UserAttributeValue.
|
||||||
|
type UserAttributeValues []*UserAttributeValue
|
||||||
139
backend/ent/userattributevalue/userattributevalue.go
Normal file
139
backend/ent/userattributevalue/userattributevalue.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package userattributevalue
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Label holds the string label denoting the userattributevalue type in the database.
|
||||||
|
Label = "user_attribute_value"
|
||||||
|
// FieldID holds the string denoting the id field in the database.
|
||||||
|
FieldID = "id"
|
||||||
|
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||||
|
FieldCreatedAt = "created_at"
|
||||||
|
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
||||||
|
FieldUpdatedAt = "updated_at"
|
||||||
|
// FieldUserID holds the string denoting the user_id field in the database.
|
||||||
|
FieldUserID = "user_id"
|
||||||
|
// FieldAttributeID holds the string denoting the attribute_id field in the database.
|
||||||
|
FieldAttributeID = "attribute_id"
|
||||||
|
// FieldValue holds the string denoting the value field in the database.
|
||||||
|
FieldValue = "value"
|
||||||
|
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||||
|
EdgeUser = "user"
|
||||||
|
// EdgeDefinition holds the string denoting the definition edge name in mutations.
|
||||||
|
EdgeDefinition = "definition"
|
||||||
|
// Table holds the table name of the userattributevalue in the database.
|
||||||
|
Table = "user_attribute_values"
|
||||||
|
// UserTable is the table that holds the user relation/edge.
|
||||||
|
UserTable = "user_attribute_values"
|
||||||
|
// UserInverseTable is the table name for the User entity.
|
||||||
|
// It exists in this package in order to avoid circular dependency with the "user" package.
|
||||||
|
UserInverseTable = "users"
|
||||||
|
// UserColumn is the table column denoting the user relation/edge.
|
||||||
|
UserColumn = "user_id"
|
||||||
|
// DefinitionTable is the table that holds the definition relation/edge.
|
||||||
|
DefinitionTable = "user_attribute_values"
|
||||||
|
// DefinitionInverseTable is the table name for the UserAttributeDefinition entity.
|
||||||
|
// It exists in this package in order to avoid circular dependency with the "userattributedefinition" package.
|
||||||
|
DefinitionInverseTable = "user_attribute_definitions"
|
||||||
|
// DefinitionColumn is the table column denoting the definition relation/edge.
|
||||||
|
DefinitionColumn = "attribute_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Columns holds all SQL columns for userattributevalue fields.
|
||||||
|
var Columns = []string{
|
||||||
|
FieldID,
|
||||||
|
FieldCreatedAt,
|
||||||
|
FieldUpdatedAt,
|
||||||
|
FieldUserID,
|
||||||
|
FieldAttributeID,
|
||||||
|
FieldValue,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||||
|
func ValidColumn(column string) bool {
|
||||||
|
for i := range Columns {
|
||||||
|
if column == Columns[i] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||||
|
DefaultCreatedAt func() time.Time
|
||||||
|
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||||
|
DefaultUpdatedAt func() time.Time
|
||||||
|
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
|
||||||
|
UpdateDefaultUpdatedAt func() time.Time
|
||||||
|
// DefaultValue holds the default value on creation for the "value" field.
|
||||||
|
DefaultValue string
|
||||||
|
)
|
||||||
|
|
||||||
|
// OrderOption defines the ordering options for the UserAttributeValue queries.
|
||||||
|
type OrderOption func(*sql.Selector)
|
||||||
|
|
||||||
|
// ByID orders the results by the id field.
|
||||||
|
func ByID(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldID, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByCreatedAt orders the results by the created_at field.
|
||||||
|
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByUpdatedAt orders the results by the updated_at field.
|
||||||
|
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByUserID orders the results by the user_id field.
|
||||||
|
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldUserID, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByAttributeID orders the results by the attribute_id field.
|
||||||
|
func ByAttributeID(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldAttributeID, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByValue orders the results by the value field.
|
||||||
|
func ByValue(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldValue, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByUserField orders the results by user field.
|
||||||
|
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return func(s *sql.Selector) {
|
||||||
|
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByDefinitionField orders the results by definition field.
|
||||||
|
func ByDefinitionField(field string, opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return func(s *sql.Selector) {
|
||||||
|
sqlgraph.OrderByNeighborTerms(s, newDefinitionStep(), sql.OrderByField(field, opts...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func newUserStep() *sqlgraph.Step {
|
||||||
|
return sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(Table, FieldID),
|
||||||
|
sqlgraph.To(UserInverseTable, FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
func newDefinitionStep() *sqlgraph.Step {
|
||||||
|
return sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(Table, FieldID),
|
||||||
|
sqlgraph.To(DefinitionInverseTable, FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.M2O, true, DefinitionTable, DefinitionColumn),
|
||||||
|
)
|
||||||
|
}
|
||||||
327
backend/ent/userattributevalue/where.go
Normal file
327
backend/ent/userattributevalue/where.go
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package userattributevalue
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ID filters vertices based on their ID field.
|
||||||
|
func ID(id int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDEQ applies the EQ predicate on the ID field.
|
||||||
|
func IDEQ(id int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDNEQ applies the NEQ predicate on the ID field.
|
||||||
|
func IDNEQ(id int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNEQ(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDIn applies the In predicate on the ID field.
|
||||||
|
func IDIn(ids ...int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldIn(FieldID, ids...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDNotIn applies the NotIn predicate on the ID field.
|
||||||
|
func IDNotIn(ids ...int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNotIn(FieldID, ids...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDGT applies the GT predicate on the ID field.
|
||||||
|
func IDGT(id int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldGT(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDGTE applies the GTE predicate on the ID field.
|
||||||
|
func IDGTE(id int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldGTE(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDLT applies the LT predicate on the ID field.
|
||||||
|
func IDLT(id int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldLT(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDLTE applies the LTE predicate on the ID field.
|
||||||
|
func IDLTE(id int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldLTE(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||||
|
func CreatedAt(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
|
||||||
|
func UpdatedAt(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
|
||||||
|
func UserID(v int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldUserID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttributeID applies equality check predicate on the "attribute_id" field. It's identical to AttributeIDEQ.
|
||||||
|
func AttributeID(v int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldAttributeID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value applies equality check predicate on the "value" field. It's identical to ValueEQ.
|
||||||
|
func Value(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
|
func CreatedAtEQ(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
|
||||||
|
func CreatedAtNEQ(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNEQ(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtIn applies the In predicate on the "created_at" field.
|
||||||
|
func CreatedAtIn(vs ...time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldIn(FieldCreatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
|
||||||
|
func CreatedAtNotIn(vs ...time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNotIn(FieldCreatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtGT applies the GT predicate on the "created_at" field.
|
||||||
|
func CreatedAtGT(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldGT(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
|
||||||
|
func CreatedAtGTE(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldGTE(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtLT applies the LT predicate on the "created_at" field.
|
||||||
|
func CreatedAtLT(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldLT(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
|
||||||
|
func CreatedAtLTE(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldLTE(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtEQ(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtNEQ(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNEQ(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtIn applies the In predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtIn(vs ...time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldIn(FieldUpdatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtNotIn(vs ...time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNotIn(FieldUpdatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtGT(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldGT(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtGTE(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldGTE(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtLT(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldLT(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtLTE(v time.Time) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldLTE(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserIDEQ applies the EQ predicate on the "user_id" field.
|
||||||
|
func UserIDEQ(v int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldUserID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
|
||||||
|
func UserIDNEQ(v int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNEQ(FieldUserID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserIDIn applies the In predicate on the "user_id" field.
|
||||||
|
func UserIDIn(vs ...int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldIn(FieldUserID, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
|
||||||
|
func UserIDNotIn(vs ...int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNotIn(FieldUserID, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttributeIDEQ applies the EQ predicate on the "attribute_id" field.
|
||||||
|
func AttributeIDEQ(v int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldAttributeID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttributeIDNEQ applies the NEQ predicate on the "attribute_id" field.
|
||||||
|
func AttributeIDNEQ(v int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNEQ(FieldAttributeID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttributeIDIn applies the In predicate on the "attribute_id" field.
|
||||||
|
func AttributeIDIn(vs ...int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldIn(FieldAttributeID, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttributeIDNotIn applies the NotIn predicate on the "attribute_id" field.
|
||||||
|
func AttributeIDNotIn(vs ...int64) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNotIn(FieldAttributeID, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueEQ applies the EQ predicate on the "value" field.
|
||||||
|
func ValueEQ(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEQ(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueNEQ applies the NEQ predicate on the "value" field.
|
||||||
|
func ValueNEQ(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNEQ(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueIn applies the In predicate on the "value" field.
|
||||||
|
func ValueIn(vs ...string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldIn(FieldValue, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueNotIn applies the NotIn predicate on the "value" field.
|
||||||
|
func ValueNotIn(vs ...string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldNotIn(FieldValue, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueGT applies the GT predicate on the "value" field.
|
||||||
|
func ValueGT(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldGT(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueGTE applies the GTE predicate on the "value" field.
|
||||||
|
func ValueGTE(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldGTE(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueLT applies the LT predicate on the "value" field.
|
||||||
|
func ValueLT(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldLT(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueLTE applies the LTE predicate on the "value" field.
|
||||||
|
func ValueLTE(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldLTE(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueContains applies the Contains predicate on the "value" field.
|
||||||
|
func ValueContains(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldContains(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueHasPrefix applies the HasPrefix predicate on the "value" field.
|
||||||
|
func ValueHasPrefix(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldHasPrefix(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueHasSuffix applies the HasSuffix predicate on the "value" field.
|
||||||
|
func ValueHasSuffix(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldHasSuffix(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueEqualFold applies the EqualFold predicate on the "value" field.
|
||||||
|
func ValueEqualFold(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldEqualFold(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueContainsFold applies the ContainsFold predicate on the "value" field.
|
||||||
|
func ValueContainsFold(v string) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.FieldContainsFold(FieldValue, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasUser applies the HasEdge predicate on the "user" edge.
|
||||||
|
func HasUser() predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(func(s *sql.Selector) {
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(Table, FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
|
||||||
|
)
|
||||||
|
sqlgraph.HasNeighbors(s, step)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
|
||||||
|
func HasUserWith(preds ...predicate.User) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(func(s *sql.Selector) {
|
||||||
|
step := newUserStep()
|
||||||
|
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||||
|
for _, p := range preds {
|
||||||
|
p(s)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasDefinition applies the HasEdge predicate on the "definition" edge.
|
||||||
|
func HasDefinition() predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(func(s *sql.Selector) {
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(Table, FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.M2O, true, DefinitionTable, DefinitionColumn),
|
||||||
|
)
|
||||||
|
sqlgraph.HasNeighbors(s, step)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasDefinitionWith applies the HasEdge predicate on the "definition" edge with a given conditions (other predicates).
|
||||||
|
func HasDefinitionWith(preds ...predicate.UserAttributeDefinition) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(func(s *sql.Selector) {
|
||||||
|
step := newDefinitionStep()
|
||||||
|
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||||
|
for _, p := range preds {
|
||||||
|
p(s)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// And groups predicates with the AND operator between them.
|
||||||
|
func And(predicates ...predicate.UserAttributeValue) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.AndPredicates(predicates...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Or groups predicates with the OR operator between them.
|
||||||
|
func Or(predicates ...predicate.UserAttributeValue) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.OrPredicates(predicates...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not applies the not operator on the given predicate.
|
||||||
|
func Not(p predicate.UserAttributeValue) predicate.UserAttributeValue {
|
||||||
|
return predicate.UserAttributeValue(sql.NotPredicates(p))
|
||||||
|
}
|
||||||
731
backend/ent/userattributevalue_create.go
Normal file
731
backend/ent/userattributevalue_create.go
Normal file
@@ -0,0 +1,731 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeValueCreate is the builder for creating a UserAttributeValue entity.
|
||||||
|
type UserAttributeValueCreate struct {
|
||||||
|
config
|
||||||
|
mutation *UserAttributeValueMutation
|
||||||
|
hooks []Hook
|
||||||
|
conflict []sql.ConflictOption
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCreatedAt sets the "created_at" field.
|
||||||
|
func (_c *UserAttributeValueCreate) SetCreatedAt(v time.Time) *UserAttributeValueCreate {
|
||||||
|
_c.mutation.SetCreatedAt(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
|
||||||
|
func (_c *UserAttributeValueCreate) SetNillableCreatedAt(v *time.Time) *UserAttributeValueCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetCreatedAt(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (_c *UserAttributeValueCreate) SetUpdatedAt(v time.Time) *UserAttributeValueCreate {
|
||||||
|
_c.mutation.SetUpdatedAt(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
|
||||||
|
func (_c *UserAttributeValueCreate) SetNillableUpdatedAt(v *time.Time) *UserAttributeValueCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetUpdatedAt(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserID sets the "user_id" field.
|
||||||
|
func (_c *UserAttributeValueCreate) SetUserID(v int64) *UserAttributeValueCreate {
|
||||||
|
_c.mutation.SetUserID(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAttributeID sets the "attribute_id" field.
|
||||||
|
func (_c *UserAttributeValueCreate) SetAttributeID(v int64) *UserAttributeValueCreate {
|
||||||
|
_c.mutation.SetAttributeID(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetValue sets the "value" field.
|
||||||
|
func (_c *UserAttributeValueCreate) SetValue(v string) *UserAttributeValueCreate {
|
||||||
|
_c.mutation.SetValue(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableValue sets the "value" field if the given value is not nil.
|
||||||
|
func (_c *UserAttributeValueCreate) SetNillableValue(v *string) *UserAttributeValueCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetValue(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUser sets the "user" edge to the User entity.
|
||||||
|
func (_c *UserAttributeValueCreate) SetUser(v *User) *UserAttributeValueCreate {
|
||||||
|
return _c.SetUserID(v.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefinitionID sets the "definition" edge to the UserAttributeDefinition entity by ID.
|
||||||
|
func (_c *UserAttributeValueCreate) SetDefinitionID(id int64) *UserAttributeValueCreate {
|
||||||
|
_c.mutation.SetDefinitionID(id)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefinition sets the "definition" edge to the UserAttributeDefinition entity.
|
||||||
|
func (_c *UserAttributeValueCreate) SetDefinition(v *UserAttributeDefinition) *UserAttributeValueCreate {
|
||||||
|
return _c.SetDefinitionID(v.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutation returns the UserAttributeValueMutation object of the builder.
|
||||||
|
func (_c *UserAttributeValueCreate) Mutation() *UserAttributeValueMutation {
|
||||||
|
return _c.mutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save creates the UserAttributeValue in the database.
|
||||||
|
func (_c *UserAttributeValueCreate) Save(ctx context.Context) (*UserAttributeValue, error) {
|
||||||
|
_c.defaults()
|
||||||
|
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveX calls Save and panics if Save returns an error.
|
||||||
|
func (_c *UserAttributeValueCreate) SaveX(ctx context.Context) *UserAttributeValue {
|
||||||
|
v, err := _c.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query.
|
||||||
|
func (_c *UserAttributeValueCreate) Exec(ctx context.Context) error {
|
||||||
|
_, err := _c.Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_c *UserAttributeValueCreate) ExecX(ctx context.Context) {
|
||||||
|
if err := _c.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaults sets the default values of the builder before save.
|
||||||
|
func (_c *UserAttributeValueCreate) defaults() {
|
||||||
|
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||||
|
v := userattributevalue.DefaultCreatedAt()
|
||||||
|
_c.mutation.SetCreatedAt(v)
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.UpdatedAt(); !ok {
|
||||||
|
v := userattributevalue.DefaultUpdatedAt()
|
||||||
|
_c.mutation.SetUpdatedAt(v)
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.Value(); !ok {
|
||||||
|
v := userattributevalue.DefaultValue
|
||||||
|
_c.mutation.SetValue(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check runs all checks and user-defined validators on the builder.
|
||||||
|
func (_c *UserAttributeValueCreate) check() error {
|
||||||
|
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||||
|
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UserAttributeValue.created_at"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.UpdatedAt(); !ok {
|
||||||
|
return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "UserAttributeValue.updated_at"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.UserID(); !ok {
|
||||||
|
return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "UserAttributeValue.user_id"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.AttributeID(); !ok {
|
||||||
|
return &ValidationError{Name: "attribute_id", err: errors.New(`ent: missing required field "UserAttributeValue.attribute_id"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.Value(); !ok {
|
||||||
|
return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "UserAttributeValue.value"`)}
|
||||||
|
}
|
||||||
|
if len(_c.mutation.UserIDs()) == 0 {
|
||||||
|
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "UserAttributeValue.user"`)}
|
||||||
|
}
|
||||||
|
if len(_c.mutation.DefinitionIDs()) == 0 {
|
||||||
|
return &ValidationError{Name: "definition", err: errors.New(`ent: missing required edge "UserAttributeValue.definition"`)}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *UserAttributeValueCreate) sqlSave(ctx context.Context) (*UserAttributeValue, error) {
|
||||||
|
if err := _c.check(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_node, _spec := _c.createSpec()
|
||||||
|
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
|
||||||
|
if sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
id := _spec.ID.Value.(int64)
|
||||||
|
_node.ID = int64(id)
|
||||||
|
_c.mutation.id = &_node.ID
|
||||||
|
_c.mutation.done = true
|
||||||
|
return _node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *UserAttributeValueCreate) createSpec() (*UserAttributeValue, *sqlgraph.CreateSpec) {
|
||||||
|
var (
|
||||||
|
_node = &UserAttributeValue{config: _c.config}
|
||||||
|
_spec = sqlgraph.NewCreateSpec(userattributevalue.Table, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
|
||||||
|
)
|
||||||
|
_spec.OnConflict = _c.conflict
|
||||||
|
if value, ok := _c.mutation.CreatedAt(); ok {
|
||||||
|
_spec.SetField(userattributevalue.FieldCreatedAt, field.TypeTime, value)
|
||||||
|
_node.CreatedAt = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.UpdatedAt(); ok {
|
||||||
|
_spec.SetField(userattributevalue.FieldUpdatedAt, field.TypeTime, value)
|
||||||
|
_node.UpdatedAt = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.Value(); ok {
|
||||||
|
_spec.SetField(userattributevalue.FieldValue, field.TypeString, value)
|
||||||
|
_node.Value = value
|
||||||
|
}
|
||||||
|
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.UserTable,
|
||||||
|
Columns: []string{userattributevalue.UserColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_node.UserID = nodes[0]
|
||||||
|
_spec.Edges = append(_spec.Edges, edge)
|
||||||
|
}
|
||||||
|
if nodes := _c.mutation.DefinitionIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.DefinitionTable,
|
||||||
|
Columns: []string{userattributevalue.DefinitionColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_node.AttributeID = nodes[0]
|
||||||
|
_spec.Edges = append(_spec.Edges, edge)
|
||||||
|
}
|
||||||
|
return _node, _spec
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
|
||||||
|
// of the `INSERT` statement. For example:
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.Create().
|
||||||
|
// SetCreatedAt(v).
|
||||||
|
// OnConflict(
|
||||||
|
// // Update the row with the new values
|
||||||
|
// // the was proposed for insertion.
|
||||||
|
// sql.ResolveWithNewValues(),
|
||||||
|
// ).
|
||||||
|
// // Override some of the fields with custom
|
||||||
|
// // update values.
|
||||||
|
// Update(func(u *ent.UserAttributeValueUpsert) {
|
||||||
|
// SetCreatedAt(v+v).
|
||||||
|
// }).
|
||||||
|
// Exec(ctx)
|
||||||
|
func (_c *UserAttributeValueCreate) OnConflict(opts ...sql.ConflictOption) *UserAttributeValueUpsertOne {
|
||||||
|
_c.conflict = opts
|
||||||
|
return &UserAttributeValueUpsertOne{
|
||||||
|
create: _c,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConflictColumns calls `OnConflict` and configures the columns
|
||||||
|
// as conflict target. Using this option is equivalent to using:
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.Create().
|
||||||
|
// OnConflict(sql.ConflictColumns(columns...)).
|
||||||
|
// Exec(ctx)
|
||||||
|
func (_c *UserAttributeValueCreate) OnConflictColumns(columns ...string) *UserAttributeValueUpsertOne {
|
||||||
|
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
|
||||||
|
return &UserAttributeValueUpsertOne{
|
||||||
|
create: _c,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
// UserAttributeValueUpsertOne is the builder for "upsert"-ing
|
||||||
|
// one UserAttributeValue node.
|
||||||
|
UserAttributeValueUpsertOne struct {
|
||||||
|
create *UserAttributeValueCreate
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueUpsert is the "OnConflict" setter.
|
||||||
|
UserAttributeValueUpsert struct {
|
||||||
|
*sql.UpdateSet
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (u *UserAttributeValueUpsert) SetUpdatedAt(v time.Time) *UserAttributeValueUpsert {
|
||||||
|
u.Set(userattributevalue.FieldUpdatedAt, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsert) UpdateUpdatedAt() *UserAttributeValueUpsert {
|
||||||
|
u.SetExcluded(userattributevalue.FieldUpdatedAt)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserID sets the "user_id" field.
|
||||||
|
func (u *UserAttributeValueUpsert) SetUserID(v int64) *UserAttributeValueUpsert {
|
||||||
|
u.Set(userattributevalue.FieldUserID, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserID sets the "user_id" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsert) UpdateUserID() *UserAttributeValueUpsert {
|
||||||
|
u.SetExcluded(userattributevalue.FieldUserID)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAttributeID sets the "attribute_id" field.
|
||||||
|
func (u *UserAttributeValueUpsert) SetAttributeID(v int64) *UserAttributeValueUpsert {
|
||||||
|
u.Set(userattributevalue.FieldAttributeID, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAttributeID sets the "attribute_id" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsert) UpdateAttributeID() *UserAttributeValueUpsert {
|
||||||
|
u.SetExcluded(userattributevalue.FieldAttributeID)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetValue sets the "value" field.
|
||||||
|
func (u *UserAttributeValueUpsert) SetValue(v string) *UserAttributeValueUpsert {
|
||||||
|
u.Set(userattributevalue.FieldValue, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateValue sets the "value" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsert) UpdateValue() *UserAttributeValueUpsert {
|
||||||
|
u.SetExcluded(userattributevalue.FieldValue)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
|
// Using this option is equivalent to using:
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.Create().
|
||||||
|
// OnConflict(
|
||||||
|
// sql.ResolveWithNewValues(),
|
||||||
|
// ).
|
||||||
|
// Exec(ctx)
|
||||||
|
func (u *UserAttributeValueUpsertOne) UpdateNewValues() *UserAttributeValueUpsertOne {
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
|
||||||
|
if _, exists := u.create.mutation.CreatedAt(); exists {
|
||||||
|
s.SetIgnore(userattributevalue.FieldCreatedAt)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore sets each column to itself in case of conflict.
|
||||||
|
// Using this option is equivalent to using:
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.Create().
|
||||||
|
// OnConflict(sql.ResolveWithIgnore()).
|
||||||
|
// Exec(ctx)
|
||||||
|
func (u *UserAttributeValueUpsertOne) Ignore() *UserAttributeValueUpsertOne {
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoNothing configures the conflict_action to `DO NOTHING`.
|
||||||
|
// Supported only by SQLite and PostgreSQL.
|
||||||
|
func (u *UserAttributeValueUpsertOne) DoNothing() *UserAttributeValueUpsertOne {
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.DoNothing())
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update allows overriding fields `UPDATE` values. See the UserAttributeValueCreate.OnConflict
|
||||||
|
// documentation for more info.
|
||||||
|
func (u *UserAttributeValueUpsertOne) Update(set func(*UserAttributeValueUpsert)) *UserAttributeValueUpsertOne {
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
|
||||||
|
set(&UserAttributeValueUpsert{UpdateSet: update})
|
||||||
|
}))
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (u *UserAttributeValueUpsertOne) SetUpdatedAt(v time.Time) *UserAttributeValueUpsertOne {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.SetUpdatedAt(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsertOne) UpdateUpdatedAt() *UserAttributeValueUpsertOne {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.UpdateUpdatedAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserID sets the "user_id" field.
|
||||||
|
func (u *UserAttributeValueUpsertOne) SetUserID(v int64) *UserAttributeValueUpsertOne {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.SetUserID(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserID sets the "user_id" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsertOne) UpdateUserID() *UserAttributeValueUpsertOne {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.UpdateUserID()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAttributeID sets the "attribute_id" field.
|
||||||
|
func (u *UserAttributeValueUpsertOne) SetAttributeID(v int64) *UserAttributeValueUpsertOne {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.SetAttributeID(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAttributeID sets the "attribute_id" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsertOne) UpdateAttributeID() *UserAttributeValueUpsertOne {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.UpdateAttributeID()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetValue sets the "value" field.
|
||||||
|
func (u *UserAttributeValueUpsertOne) SetValue(v string) *UserAttributeValueUpsertOne {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.SetValue(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateValue sets the "value" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsertOne) UpdateValue() *UserAttributeValueUpsertOne {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.UpdateValue()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query.
|
||||||
|
func (u *UserAttributeValueUpsertOne) Exec(ctx context.Context) error {
|
||||||
|
if len(u.create.conflict) == 0 {
|
||||||
|
return errors.New("ent: missing options for UserAttributeValueCreate.OnConflict")
|
||||||
|
}
|
||||||
|
return u.create.Exec(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (u *UserAttributeValueUpsertOne) ExecX(ctx context.Context) {
|
||||||
|
if err := u.create.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the UPSERT query and returns the inserted/updated ID.
|
||||||
|
func (u *UserAttributeValueUpsertOne) ID(ctx context.Context) (id int64, err error) {
|
||||||
|
node, err := u.create.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
return node.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDX is like ID, but panics if an error occurs.
|
||||||
|
func (u *UserAttributeValueUpsertOne) IDX(ctx context.Context) int64 {
|
||||||
|
id, err := u.ID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueCreateBulk is the builder for creating many UserAttributeValue entities in bulk.
|
||||||
|
type UserAttributeValueCreateBulk struct {
|
||||||
|
config
|
||||||
|
err error
|
||||||
|
builders []*UserAttributeValueCreate
|
||||||
|
conflict []sql.ConflictOption
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save creates the UserAttributeValue entities in the database.
|
||||||
|
func (_c *UserAttributeValueCreateBulk) Save(ctx context.Context) ([]*UserAttributeValue, error) {
|
||||||
|
if _c.err != nil {
|
||||||
|
return nil, _c.err
|
||||||
|
}
|
||||||
|
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
|
||||||
|
nodes := make([]*UserAttributeValue, len(_c.builders))
|
||||||
|
mutators := make([]Mutator, len(_c.builders))
|
||||||
|
for i := range _c.builders {
|
||||||
|
func(i int, root context.Context) {
|
||||||
|
builder := _c.builders[i]
|
||||||
|
builder.defaults()
|
||||||
|
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
|
||||||
|
mutation, ok := m.(*UserAttributeValueMutation)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected mutation type %T", m)
|
||||||
|
}
|
||||||
|
if err := builder.check(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
builder.mutation = mutation
|
||||||
|
var err error
|
||||||
|
nodes[i], specs[i] = builder.createSpec()
|
||||||
|
if i < len(mutators)-1 {
|
||||||
|
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
|
||||||
|
} else {
|
||||||
|
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
|
||||||
|
spec.OnConflict = _c.conflict
|
||||||
|
// Invoke the actual operation on the latest mutation in the chain.
|
||||||
|
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
|
||||||
|
if sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mutation.id = &nodes[i].ID
|
||||||
|
if specs[i].ID.Value != nil {
|
||||||
|
id := specs[i].ID.Value.(int64)
|
||||||
|
nodes[i].ID = int64(id)
|
||||||
|
}
|
||||||
|
mutation.done = true
|
||||||
|
return nodes[i], nil
|
||||||
|
})
|
||||||
|
for i := len(builder.hooks) - 1; i >= 0; i-- {
|
||||||
|
mut = builder.hooks[i](mut)
|
||||||
|
}
|
||||||
|
mutators[i] = mut
|
||||||
|
}(i, ctx)
|
||||||
|
}
|
||||||
|
if len(mutators) > 0 {
|
||||||
|
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveX is like Save, but panics if an error occurs.
|
||||||
|
func (_c *UserAttributeValueCreateBulk) SaveX(ctx context.Context) []*UserAttributeValue {
|
||||||
|
v, err := _c.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query.
|
||||||
|
func (_c *UserAttributeValueCreateBulk) Exec(ctx context.Context) error {
|
||||||
|
_, err := _c.Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_c *UserAttributeValueCreateBulk) ExecX(ctx context.Context) {
|
||||||
|
if err := _c.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
|
||||||
|
// of the `INSERT` statement. For example:
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.CreateBulk(builders...).
|
||||||
|
// OnConflict(
|
||||||
|
// // Update the row with the new values
|
||||||
|
// // the was proposed for insertion.
|
||||||
|
// sql.ResolveWithNewValues(),
|
||||||
|
// ).
|
||||||
|
// // Override some of the fields with custom
|
||||||
|
// // update values.
|
||||||
|
// Update(func(u *ent.UserAttributeValueUpsert) {
|
||||||
|
// SetCreatedAt(v+v).
|
||||||
|
// }).
|
||||||
|
// Exec(ctx)
|
||||||
|
func (_c *UserAttributeValueCreateBulk) OnConflict(opts ...sql.ConflictOption) *UserAttributeValueUpsertBulk {
|
||||||
|
_c.conflict = opts
|
||||||
|
return &UserAttributeValueUpsertBulk{
|
||||||
|
create: _c,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConflictColumns calls `OnConflict` and configures the columns
|
||||||
|
// as conflict target. Using this option is equivalent to using:
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.Create().
|
||||||
|
// OnConflict(sql.ConflictColumns(columns...)).
|
||||||
|
// Exec(ctx)
|
||||||
|
func (_c *UserAttributeValueCreateBulk) OnConflictColumns(columns ...string) *UserAttributeValueUpsertBulk {
|
||||||
|
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
|
||||||
|
return &UserAttributeValueUpsertBulk{
|
||||||
|
create: _c,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueUpsertBulk is the builder for "upsert"-ing
|
||||||
|
// a bulk of UserAttributeValue nodes.
|
||||||
|
type UserAttributeValueUpsertBulk struct {
|
||||||
|
create *UserAttributeValueCreateBulk
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateNewValues updates the mutable fields using the new values that
|
||||||
|
// were set on create. Using this option is equivalent to using:
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.Create().
|
||||||
|
// OnConflict(
|
||||||
|
// sql.ResolveWithNewValues(),
|
||||||
|
// ).
|
||||||
|
// Exec(ctx)
|
||||||
|
func (u *UserAttributeValueUpsertBulk) UpdateNewValues() *UserAttributeValueUpsertBulk {
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
|
||||||
|
for _, b := range u.create.builders {
|
||||||
|
if _, exists := b.mutation.CreatedAt(); exists {
|
||||||
|
s.SetIgnore(userattributevalue.FieldCreatedAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore sets each column to itself in case of conflict.
|
||||||
|
// Using this option is equivalent to using:
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.Create().
|
||||||
|
// OnConflict(sql.ResolveWithIgnore()).
|
||||||
|
// Exec(ctx)
|
||||||
|
func (u *UserAttributeValueUpsertBulk) Ignore() *UserAttributeValueUpsertBulk {
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoNothing configures the conflict_action to `DO NOTHING`.
|
||||||
|
// Supported only by SQLite and PostgreSQL.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) DoNothing() *UserAttributeValueUpsertBulk {
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.DoNothing())
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update allows overriding fields `UPDATE` values. See the UserAttributeValueCreateBulk.OnConflict
|
||||||
|
// documentation for more info.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) Update(set func(*UserAttributeValueUpsert)) *UserAttributeValueUpsertBulk {
|
||||||
|
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
|
||||||
|
set(&UserAttributeValueUpsert{UpdateSet: update})
|
||||||
|
}))
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) SetUpdatedAt(v time.Time) *UserAttributeValueUpsertBulk {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.SetUpdatedAt(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) UpdateUpdatedAt() *UserAttributeValueUpsertBulk {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.UpdateUpdatedAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserID sets the "user_id" field.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) SetUserID(v int64) *UserAttributeValueUpsertBulk {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.SetUserID(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserID sets the "user_id" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) UpdateUserID() *UserAttributeValueUpsertBulk {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.UpdateUserID()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAttributeID sets the "attribute_id" field.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) SetAttributeID(v int64) *UserAttributeValueUpsertBulk {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.SetAttributeID(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAttributeID sets the "attribute_id" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) UpdateAttributeID() *UserAttributeValueUpsertBulk {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.UpdateAttributeID()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetValue sets the "value" field.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) SetValue(v string) *UserAttributeValueUpsertBulk {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.SetValue(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateValue sets the "value" field to the value that was provided on create.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) UpdateValue() *UserAttributeValueUpsertBulk {
|
||||||
|
return u.Update(func(s *UserAttributeValueUpsert) {
|
||||||
|
s.UpdateValue()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) Exec(ctx context.Context) error {
|
||||||
|
if u.create.err != nil {
|
||||||
|
return u.create.err
|
||||||
|
}
|
||||||
|
for i, b := range u.create.builders {
|
||||||
|
if len(b.conflict) != 0 {
|
||||||
|
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UserAttributeValueCreateBulk instead", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(u.create.conflict) == 0 {
|
||||||
|
return errors.New("ent: missing options for UserAttributeValueCreateBulk.OnConflict")
|
||||||
|
}
|
||||||
|
return u.create.Exec(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (u *UserAttributeValueUpsertBulk) ExecX(ctx context.Context) {
|
||||||
|
if err := u.create.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
88
backend/ent/userattributevalue_delete.go
Normal file
88
backend/ent/userattributevalue_delete.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeValueDelete is the builder for deleting a UserAttributeValue entity.
|
||||||
|
type UserAttributeValueDelete struct {
|
||||||
|
config
|
||||||
|
hooks []Hook
|
||||||
|
mutation *UserAttributeValueMutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the UserAttributeValueDelete builder.
|
||||||
|
func (_d *UserAttributeValueDelete) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueDelete {
|
||||||
|
_d.mutation.Where(ps...)
|
||||||
|
return _d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||||
|
func (_d *UserAttributeValueDelete) Exec(ctx context.Context) (int, error) {
|
||||||
|
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_d *UserAttributeValueDelete) ExecX(ctx context.Context) int {
|
||||||
|
n, err := _d.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_d *UserAttributeValueDelete) sqlExec(ctx context.Context) (int, error) {
|
||||||
|
_spec := sqlgraph.NewDeleteSpec(userattributevalue.Table, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
|
||||||
|
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||||
|
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
_d.mutation.done = true
|
||||||
|
return affected, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueDeleteOne is the builder for deleting a single UserAttributeValue entity.
|
||||||
|
type UserAttributeValueDeleteOne struct {
|
||||||
|
_d *UserAttributeValueDelete
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the UserAttributeValueDelete builder.
|
||||||
|
func (_d *UserAttributeValueDeleteOne) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueDeleteOne {
|
||||||
|
_d._d.mutation.Where(ps...)
|
||||||
|
return _d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the deletion query.
|
||||||
|
func (_d *UserAttributeValueDeleteOne) Exec(ctx context.Context) error {
|
||||||
|
n, err := _d._d.Exec(ctx)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return err
|
||||||
|
case n == 0:
|
||||||
|
return &NotFoundError{userattributevalue.Label}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_d *UserAttributeValueDeleteOne) ExecX(ctx context.Context) {
|
||||||
|
if err := _d.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
681
backend/ent/userattributevalue_query.go
Normal file
681
backend/ent/userattributevalue_query.go
Normal file
@@ -0,0 +1,681 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeValueQuery is the builder for querying UserAttributeValue entities.
|
||||||
|
type UserAttributeValueQuery struct {
|
||||||
|
config
|
||||||
|
ctx *QueryContext
|
||||||
|
order []userattributevalue.OrderOption
|
||||||
|
inters []Interceptor
|
||||||
|
predicates []predicate.UserAttributeValue
|
||||||
|
withUser *UserQuery
|
||||||
|
withDefinition *UserAttributeDefinitionQuery
|
||||||
|
// intermediate query (i.e. traversal path).
|
||||||
|
sql *sql.Selector
|
||||||
|
path func(context.Context) (*sql.Selector, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where adds a new predicate for the UserAttributeValueQuery builder.
|
||||||
|
func (_q *UserAttributeValueQuery) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueQuery {
|
||||||
|
_q.predicates = append(_q.predicates, ps...)
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit the number of records to be returned by this query.
|
||||||
|
func (_q *UserAttributeValueQuery) Limit(limit int) *UserAttributeValueQuery {
|
||||||
|
_q.ctx.Limit = &limit
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Offset to start from.
|
||||||
|
func (_q *UserAttributeValueQuery) Offset(offset int) *UserAttributeValueQuery {
|
||||||
|
_q.ctx.Offset = &offset
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unique configures the query builder to filter duplicate records on query.
|
||||||
|
// By default, unique is set to true, and can be disabled using this method.
|
||||||
|
func (_q *UserAttributeValueQuery) Unique(unique bool) *UserAttributeValueQuery {
|
||||||
|
_q.ctx.Unique = &unique
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Order specifies how the records should be ordered.
|
||||||
|
func (_q *UserAttributeValueQuery) Order(o ...userattributevalue.OrderOption) *UserAttributeValueQuery {
|
||||||
|
_q.order = append(_q.order, o...)
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryUser chains the current query on the "user" edge.
|
||||||
|
func (_q *UserAttributeValueQuery) QueryUser() *UserQuery {
|
||||||
|
query := (&UserClient{config: _q.config}).Query()
|
||||||
|
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
selector := _q.sqlQuery(ctx)
|
||||||
|
if err := selector.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, selector),
|
||||||
|
sqlgraph.To(user.Table, user.FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.UserTable, userattributevalue.UserColumn),
|
||||||
|
)
|
||||||
|
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||||
|
return fromU, nil
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryDefinition chains the current query on the "definition" edge.
|
||||||
|
func (_q *UserAttributeValueQuery) QueryDefinition() *UserAttributeDefinitionQuery {
|
||||||
|
query := (&UserAttributeDefinitionClient{config: _q.config}).Query()
|
||||||
|
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
selector := _q.sqlQuery(ctx)
|
||||||
|
if err := selector.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
step := sqlgraph.NewStep(
|
||||||
|
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, selector),
|
||||||
|
sqlgraph.To(userattributedefinition.Table, userattributedefinition.FieldID),
|
||||||
|
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.DefinitionTable, userattributevalue.DefinitionColumn),
|
||||||
|
)
|
||||||
|
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||||
|
return fromU, nil
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// First returns the first UserAttributeValue entity from the query.
|
||||||
|
// Returns a *NotFoundError when no UserAttributeValue was found.
|
||||||
|
func (_q *UserAttributeValueQuery) First(ctx context.Context) (*UserAttributeValue, error) {
|
||||||
|
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return nil, &NotFoundError{userattributevalue.Label}
|
||||||
|
}
|
||||||
|
return nodes[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstX is like First, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeValueQuery) FirstX(ctx context.Context) *UserAttributeValue {
|
||||||
|
node, err := _q.First(ctx)
|
||||||
|
if err != nil && !IsNotFound(err) {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstID returns the first UserAttributeValue ID from the query.
|
||||||
|
// Returns a *NotFoundError when no UserAttributeValue ID was found.
|
||||||
|
func (_q *UserAttributeValueQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||||
|
var ids []int64
|
||||||
|
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
err = &NotFoundError{userattributevalue.Label}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return ids[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeValueQuery) FirstIDX(ctx context.Context) int64 {
|
||||||
|
id, err := _q.FirstID(ctx)
|
||||||
|
if err != nil && !IsNotFound(err) {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only returns a single UserAttributeValue entity found by the query, ensuring it only returns one.
|
||||||
|
// Returns a *NotSingularError when more than one UserAttributeValue entity is found.
|
||||||
|
// Returns a *NotFoundError when no UserAttributeValue entities are found.
|
||||||
|
func (_q *UserAttributeValueQuery) Only(ctx context.Context) (*UserAttributeValue, error) {
|
||||||
|
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch len(nodes) {
|
||||||
|
case 1:
|
||||||
|
return nodes[0], nil
|
||||||
|
case 0:
|
||||||
|
return nil, &NotFoundError{userattributevalue.Label}
|
||||||
|
default:
|
||||||
|
return nil, &NotSingularError{userattributevalue.Label}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyX is like Only, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeValueQuery) OnlyX(ctx context.Context) *UserAttributeValue {
|
||||||
|
node, err := _q.Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyID is like Only, but returns the only UserAttributeValue ID in the query.
|
||||||
|
// Returns a *NotSingularError when more than one UserAttributeValue ID is found.
|
||||||
|
// Returns a *NotFoundError when no entities are found.
|
||||||
|
func (_q *UserAttributeValueQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||||
|
var ids []int64
|
||||||
|
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch len(ids) {
|
||||||
|
case 1:
|
||||||
|
id = ids[0]
|
||||||
|
case 0:
|
||||||
|
err = &NotFoundError{userattributevalue.Label}
|
||||||
|
default:
|
||||||
|
err = &NotSingularError{userattributevalue.Label}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeValueQuery) OnlyIDX(ctx context.Context) int64 {
|
||||||
|
id, err := _q.OnlyID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// All executes the query and returns a list of UserAttributeValues.
|
||||||
|
func (_q *UserAttributeValueQuery) All(ctx context.Context) ([]*UserAttributeValue, error) {
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
qr := querierAll[[]*UserAttributeValue, *UserAttributeValueQuery]()
|
||||||
|
return withInterceptors[[]*UserAttributeValue](ctx, _q, qr, _q.inters)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllX is like All, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeValueQuery) AllX(ctx context.Context) []*UserAttributeValue {
|
||||||
|
nodes, err := _q.All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDs executes the query and returns a list of UserAttributeValue IDs.
|
||||||
|
func (_q *UserAttributeValueQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||||
|
if _q.ctx.Unique == nil && _q.path != nil {
|
||||||
|
_q.Unique(true)
|
||||||
|
}
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||||
|
if err = _q.Select(userattributevalue.FieldID).Scan(ctx, &ids); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDsX is like IDs, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeValueQuery) IDsX(ctx context.Context) []int64 {
|
||||||
|
ids, err := _q.IDs(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the count of the given query.
|
||||||
|
func (_q *UserAttributeValueQuery) Count(ctx context.Context) (int, error) {
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return withInterceptors[int](ctx, _q, querierCount[*UserAttributeValueQuery](), _q.inters)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountX is like Count, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeValueQuery) CountX(ctx context.Context) int {
|
||||||
|
count, err := _q.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exist returns true if the query has elements in the graph.
|
||||||
|
func (_q *UserAttributeValueQuery) Exist(ctx context.Context) (bool, error) {
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||||
|
switch _, err := _q.FirstID(ctx); {
|
||||||
|
case IsNotFound(err):
|
||||||
|
return false, nil
|
||||||
|
case err != nil:
|
||||||
|
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||||
|
default:
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExistX is like Exist, but panics if an error occurs.
|
||||||
|
func (_q *UserAttributeValueQuery) ExistX(ctx context.Context) bool {
|
||||||
|
exist, err := _q.Exist(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return exist
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone returns a duplicate of the UserAttributeValueQuery builder, including all associated steps. It can be
|
||||||
|
// used to prepare common query builders and use them differently after the clone is made.
|
||||||
|
func (_q *UserAttributeValueQuery) Clone() *UserAttributeValueQuery {
|
||||||
|
if _q == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UserAttributeValueQuery{
|
||||||
|
config: _q.config,
|
||||||
|
ctx: _q.ctx.Clone(),
|
||||||
|
order: append([]userattributevalue.OrderOption{}, _q.order...),
|
||||||
|
inters: append([]Interceptor{}, _q.inters...),
|
||||||
|
predicates: append([]predicate.UserAttributeValue{}, _q.predicates...),
|
||||||
|
withUser: _q.withUser.Clone(),
|
||||||
|
withDefinition: _q.withDefinition.Clone(),
|
||||||
|
// clone intermediate query.
|
||||||
|
sql: _q.sql.Clone(),
|
||||||
|
path: _q.path,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithUser tells the query-builder to eager-load the nodes that are connected to
|
||||||
|
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
|
||||||
|
func (_q *UserAttributeValueQuery) WithUser(opts ...func(*UserQuery)) *UserAttributeValueQuery {
|
||||||
|
query := (&UserClient{config: _q.config}).Query()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(query)
|
||||||
|
}
|
||||||
|
_q.withUser = query
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDefinition tells the query-builder to eager-load the nodes that are connected to
|
||||||
|
// the "definition" edge. The optional arguments are used to configure the query builder of the edge.
|
||||||
|
func (_q *UserAttributeValueQuery) WithDefinition(opts ...func(*UserAttributeDefinitionQuery)) *UserAttributeValueQuery {
|
||||||
|
query := (&UserAttributeDefinitionClient{config: _q.config}).Query()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(query)
|
||||||
|
}
|
||||||
|
_q.withDefinition = query
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupBy is used to group vertices by one or more fields/columns.
|
||||||
|
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// var v []struct {
|
||||||
|
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
|
// Count int `json:"count,omitempty"`
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.Query().
|
||||||
|
// GroupBy(userattributevalue.FieldCreatedAt).
|
||||||
|
// Aggregate(ent.Count()).
|
||||||
|
// Scan(ctx, &v)
|
||||||
|
func (_q *UserAttributeValueQuery) GroupBy(field string, fields ...string) *UserAttributeValueGroupBy {
|
||||||
|
_q.ctx.Fields = append([]string{field}, fields...)
|
||||||
|
grbuild := &UserAttributeValueGroupBy{build: _q}
|
||||||
|
grbuild.flds = &_q.ctx.Fields
|
||||||
|
grbuild.label = userattributevalue.Label
|
||||||
|
grbuild.scan = grbuild.Scan
|
||||||
|
return grbuild
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select allows the selection one or more fields/columns for the given query,
|
||||||
|
// instead of selecting all fields in the entity.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// var v []struct {
|
||||||
|
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// client.UserAttributeValue.Query().
|
||||||
|
// Select(userattributevalue.FieldCreatedAt).
|
||||||
|
// Scan(ctx, &v)
|
||||||
|
func (_q *UserAttributeValueQuery) Select(fields ...string) *UserAttributeValueSelect {
|
||||||
|
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||||
|
sbuild := &UserAttributeValueSelect{UserAttributeValueQuery: _q}
|
||||||
|
sbuild.label = userattributevalue.Label
|
||||||
|
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||||
|
return sbuild
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate returns a UserAttributeValueSelect configured with the given aggregations.
|
||||||
|
func (_q *UserAttributeValueQuery) Aggregate(fns ...AggregateFunc) *UserAttributeValueSelect {
|
||||||
|
return _q.Select().Aggregate(fns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeValueQuery) prepareQuery(ctx context.Context) error {
|
||||||
|
for _, inter := range _q.inters {
|
||||||
|
if inter == nil {
|
||||||
|
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||||
|
}
|
||||||
|
if trv, ok := inter.(Traverser); ok {
|
||||||
|
if err := trv.Traverse(ctx, _q); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, f := range _q.ctx.Fields {
|
||||||
|
if !userattributevalue.ValidColumn(f) {
|
||||||
|
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _q.path != nil {
|
||||||
|
prev, err := _q.path(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_q.sql = prev
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserAttributeValue, error) {
|
||||||
|
var (
|
||||||
|
nodes = []*UserAttributeValue{}
|
||||||
|
_spec = _q.querySpec()
|
||||||
|
loadedTypes = [2]bool{
|
||||||
|
_q.withUser != nil,
|
||||||
|
_q.withDefinition != nil,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||||
|
return (*UserAttributeValue).scanValues(nil, columns)
|
||||||
|
}
|
||||||
|
_spec.Assign = func(columns []string, values []any) error {
|
||||||
|
node := &UserAttributeValue{config: _q.config}
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
node.Edges.loadedTypes = loadedTypes
|
||||||
|
return node.assignValues(columns, values)
|
||||||
|
}
|
||||||
|
for i := range hooks {
|
||||||
|
hooks[i](ctx, _spec)
|
||||||
|
}
|
||||||
|
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
if query := _q.withUser; query != nil {
|
||||||
|
if err := _q.loadUser(ctx, query, nodes, nil,
|
||||||
|
func(n *UserAttributeValue, e *User) { n.Edges.User = e }); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if query := _q.withDefinition; query != nil {
|
||||||
|
if err := _q.loadDefinition(ctx, query, nodes, nil,
|
||||||
|
func(n *UserAttributeValue, e *UserAttributeDefinition) { n.Edges.Definition = e }); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeValueQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserAttributeValue, init func(*UserAttributeValue), assign func(*UserAttributeValue, *User)) error {
|
||||||
|
ids := make([]int64, 0, len(nodes))
|
||||||
|
nodeids := make(map[int64][]*UserAttributeValue)
|
||||||
|
for i := range nodes {
|
||||||
|
fk := nodes[i].UserID
|
||||||
|
if _, ok := nodeids[fk]; !ok {
|
||||||
|
ids = append(ids, fk)
|
||||||
|
}
|
||||||
|
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
query.Where(user.IDIn(ids...))
|
||||||
|
neighbors, err := query.All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, n := range neighbors {
|
||||||
|
nodes, ok := nodeids[n.ID]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
|
||||||
|
}
|
||||||
|
for i := range nodes {
|
||||||
|
assign(nodes[i], n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *UserAttributeDefinitionQuery, nodes []*UserAttributeValue, init func(*UserAttributeValue), assign func(*UserAttributeValue, *UserAttributeDefinition)) error {
|
||||||
|
ids := make([]int64, 0, len(nodes))
|
||||||
|
nodeids := make(map[int64][]*UserAttributeValue)
|
||||||
|
for i := range nodes {
|
||||||
|
fk := nodes[i].AttributeID
|
||||||
|
if _, ok := nodeids[fk]; !ok {
|
||||||
|
ids = append(ids, fk)
|
||||||
|
}
|
||||||
|
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
query.Where(userattributedefinition.IDIn(ids...))
|
||||||
|
neighbors, err := query.All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, n := range neighbors {
|
||||||
|
nodes, ok := nodeids[n.ID]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf(`unexpected foreign-key "attribute_id" returned %v`, n.ID)
|
||||||
|
}
|
||||||
|
for i := range nodes {
|
||||||
|
assign(nodes[i], n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) {
|
||||||
|
_spec := _q.querySpec()
|
||||||
|
_spec.Node.Columns = _q.ctx.Fields
|
||||||
|
if len(_q.ctx.Fields) > 0 {
|
||||||
|
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||||
|
}
|
||||||
|
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeValueQuery) querySpec() *sqlgraph.QuerySpec {
|
||||||
|
_spec := sqlgraph.NewQuerySpec(userattributevalue.Table, userattributevalue.Columns, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
|
||||||
|
_spec.From = _q.sql
|
||||||
|
if unique := _q.ctx.Unique; unique != nil {
|
||||||
|
_spec.Unique = *unique
|
||||||
|
} else if _q.path != nil {
|
||||||
|
_spec.Unique = true
|
||||||
|
}
|
||||||
|
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||||
|
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, userattributevalue.FieldID)
|
||||||
|
for i := range fields {
|
||||||
|
if fields[i] != userattributevalue.FieldID {
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _q.withUser != nil {
|
||||||
|
_spec.Node.AddColumnOnce(userattributevalue.FieldUserID)
|
||||||
|
}
|
||||||
|
if _q.withDefinition != nil {
|
||||||
|
_spec.Node.AddColumnOnce(userattributevalue.FieldAttributeID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ps := _q.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if limit := _q.ctx.Limit; limit != nil {
|
||||||
|
_spec.Limit = *limit
|
||||||
|
}
|
||||||
|
if offset := _q.ctx.Offset; offset != nil {
|
||||||
|
_spec.Offset = *offset
|
||||||
|
}
|
||||||
|
if ps := _q.order; len(ps) > 0 {
|
||||||
|
_spec.Order = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return _spec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||||
|
builder := sql.Dialect(_q.driver.Dialect())
|
||||||
|
t1 := builder.Table(userattributevalue.Table)
|
||||||
|
columns := _q.ctx.Fields
|
||||||
|
if len(columns) == 0 {
|
||||||
|
columns = userattributevalue.Columns
|
||||||
|
}
|
||||||
|
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||||
|
if _q.sql != nil {
|
||||||
|
selector = _q.sql
|
||||||
|
selector.Select(selector.Columns(columns...)...)
|
||||||
|
}
|
||||||
|
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||||
|
selector.Distinct()
|
||||||
|
}
|
||||||
|
for _, p := range _q.predicates {
|
||||||
|
p(selector)
|
||||||
|
}
|
||||||
|
for _, p := range _q.order {
|
||||||
|
p(selector)
|
||||||
|
}
|
||||||
|
if offset := _q.ctx.Offset; offset != nil {
|
||||||
|
// limit is mandatory for offset clause. We start
|
||||||
|
// with default value, and override it below if needed.
|
||||||
|
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||||
|
}
|
||||||
|
if limit := _q.ctx.Limit; limit != nil {
|
||||||
|
selector.Limit(*limit)
|
||||||
|
}
|
||||||
|
return selector
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
|
||||||
|
type UserAttributeValueGroupBy struct {
|
||||||
|
selector
|
||||||
|
build *UserAttributeValueQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate adds the given aggregation functions to the group-by query.
|
||||||
|
func (_g *UserAttributeValueGroupBy) Aggregate(fns ...AggregateFunc) *UserAttributeValueGroupBy {
|
||||||
|
_g.fns = append(_g.fns, fns...)
|
||||||
|
return _g
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan applies the selector query and scans the result into the given value.
|
||||||
|
func (_g *UserAttributeValueGroupBy) Scan(ctx context.Context, v any) error {
|
||||||
|
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||||
|
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanWithInterceptors[*UserAttributeValueQuery, *UserAttributeValueGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_g *UserAttributeValueGroupBy) sqlScan(ctx context.Context, root *UserAttributeValueQuery, v any) error {
|
||||||
|
selector := root.sqlQuery(ctx).Select()
|
||||||
|
aggregation := make([]string, 0, len(_g.fns))
|
||||||
|
for _, fn := range _g.fns {
|
||||||
|
aggregation = append(aggregation, fn(selector))
|
||||||
|
}
|
||||||
|
if len(selector.SelectedColumns()) == 0 {
|
||||||
|
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||||
|
for _, f := range *_g.flds {
|
||||||
|
columns = append(columns, selector.C(f))
|
||||||
|
}
|
||||||
|
columns = append(columns, aggregation...)
|
||||||
|
selector.Select(columns...)
|
||||||
|
}
|
||||||
|
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||||
|
if err := selector.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rows := &sql.Rows{}
|
||||||
|
query, args := selector.Query()
|
||||||
|
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return sql.ScanSlice(rows, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueSelect is the builder for selecting fields of UserAttributeValue entities.
|
||||||
|
type UserAttributeValueSelect struct {
|
||||||
|
*UserAttributeValueQuery
|
||||||
|
selector
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate adds the given aggregation functions to the selector query.
|
||||||
|
func (_s *UserAttributeValueSelect) Aggregate(fns ...AggregateFunc) *UserAttributeValueSelect {
|
||||||
|
_s.fns = append(_s.fns, fns...)
|
||||||
|
return _s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan applies the selector query and scans the result into the given value.
|
||||||
|
func (_s *UserAttributeValueSelect) Scan(ctx context.Context, v any) error {
|
||||||
|
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||||
|
if err := _s.prepareQuery(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanWithInterceptors[*UserAttributeValueQuery, *UserAttributeValueSelect](ctx, _s.UserAttributeValueQuery, _s, _s.inters, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_s *UserAttributeValueSelect) sqlScan(ctx context.Context, root *UserAttributeValueQuery, v any) error {
|
||||||
|
selector := root.sqlQuery(ctx)
|
||||||
|
aggregation := make([]string, 0, len(_s.fns))
|
||||||
|
for _, fn := range _s.fns {
|
||||||
|
aggregation = append(aggregation, fn(selector))
|
||||||
|
}
|
||||||
|
switch n := len(*_s.selector.flds); {
|
||||||
|
case n == 0 && len(aggregation) > 0:
|
||||||
|
selector.Select(aggregation...)
|
||||||
|
case n != 0 && len(aggregation) > 0:
|
||||||
|
selector.AppendSelect(aggregation...)
|
||||||
|
}
|
||||||
|
rows := &sql.Rows{}
|
||||||
|
query, args := selector.Query()
|
||||||
|
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return sql.ScanSlice(rows, v)
|
||||||
|
}
|
||||||
504
backend/ent/userattributevalue_update.go
Normal file
504
backend/ent/userattributevalue_update.go
Normal file
@@ -0,0 +1,504 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeValueUpdate is the builder for updating UserAttributeValue entities.
|
||||||
|
type UserAttributeValueUpdate struct {
|
||||||
|
config
|
||||||
|
hooks []Hook
|
||||||
|
mutation *UserAttributeValueMutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the UserAttributeValueUpdate builder.
|
||||||
|
func (_u *UserAttributeValueUpdate) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueUpdate {
|
||||||
|
_u.mutation.Where(ps...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetUpdatedAt(v time.Time) *UserAttributeValueUpdate {
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserID sets the "user_id" field.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetUserID(v int64) *UserAttributeValueUpdate {
|
||||||
|
_u.mutation.SetUserID(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUserID sets the "user_id" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetNillableUserID(v *int64) *UserAttributeValueUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetUserID(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAttributeID sets the "attribute_id" field.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetAttributeID(v int64) *UserAttributeValueUpdate {
|
||||||
|
_u.mutation.SetAttributeID(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableAttributeID sets the "attribute_id" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetNillableAttributeID(v *int64) *UserAttributeValueUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetAttributeID(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetValue sets the "value" field.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetValue(v string) *UserAttributeValueUpdate {
|
||||||
|
_u.mutation.SetValue(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableValue sets the "value" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetNillableValue(v *string) *UserAttributeValueUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetValue(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUser sets the "user" edge to the User entity.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetUser(v *User) *UserAttributeValueUpdate {
|
||||||
|
return _u.SetUserID(v.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefinitionID sets the "definition" edge to the UserAttributeDefinition entity by ID.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetDefinitionID(id int64) *UserAttributeValueUpdate {
|
||||||
|
_u.mutation.SetDefinitionID(id)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefinition sets the "definition" edge to the UserAttributeDefinition entity.
|
||||||
|
func (_u *UserAttributeValueUpdate) SetDefinition(v *UserAttributeDefinition) *UserAttributeValueUpdate {
|
||||||
|
return _u.SetDefinitionID(v.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutation returns the UserAttributeValueMutation object of the builder.
|
||||||
|
func (_u *UserAttributeValueUpdate) Mutation() *UserAttributeValueMutation {
|
||||||
|
return _u.mutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUser clears the "user" edge to the User entity.
|
||||||
|
func (_u *UserAttributeValueUpdate) ClearUser() *UserAttributeValueUpdate {
|
||||||
|
_u.mutation.ClearUser()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearDefinition clears the "definition" edge to the UserAttributeDefinition entity.
|
||||||
|
func (_u *UserAttributeValueUpdate) ClearDefinition() *UserAttributeValueUpdate {
|
||||||
|
_u.mutation.ClearDefinition()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||||
|
func (_u *UserAttributeValueUpdate) Save(ctx context.Context) (int, error) {
|
||||||
|
_u.defaults()
|
||||||
|
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveX is like Save, but panics if an error occurs.
|
||||||
|
func (_u *UserAttributeValueUpdate) SaveX(ctx context.Context) int {
|
||||||
|
affected, err := _u.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return affected
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query.
|
||||||
|
func (_u *UserAttributeValueUpdate) Exec(ctx context.Context) error {
|
||||||
|
_, err := _u.Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_u *UserAttributeValueUpdate) ExecX(ctx context.Context) {
|
||||||
|
if err := _u.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaults sets the default values of the builder before save.
|
||||||
|
func (_u *UserAttributeValueUpdate) defaults() {
|
||||||
|
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||||
|
v := userattributevalue.UpdateDefaultUpdatedAt()
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check runs all checks and user-defined validators on the builder.
|
||||||
|
func (_u *UserAttributeValueUpdate) check() error {
|
||||||
|
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||||
|
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.user"`)
|
||||||
|
}
|
||||||
|
if _u.mutation.DefinitionCleared() && len(_u.mutation.DefinitionIDs()) > 0 {
|
||||||
|
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.definition"`)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_u *UserAttributeValueUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||||
|
if err := _u.check(); err != nil {
|
||||||
|
return _node, err
|
||||||
|
}
|
||||||
|
_spec := sqlgraph.NewUpdateSpec(userattributevalue.Table, userattributevalue.Columns, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
|
||||||
|
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||||
|
_spec.SetField(userattributevalue.FieldUpdatedAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Value(); ok {
|
||||||
|
_spec.SetField(userattributevalue.FieldValue, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.UserCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.UserTable,
|
||||||
|
Columns: []string{userattributevalue.UserColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.UserTable,
|
||||||
|
Columns: []string{userattributevalue.UserColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
|
}
|
||||||
|
if _u.mutation.DefinitionCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.DefinitionTable,
|
||||||
|
Columns: []string{userattributevalue.DefinitionColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.DefinitionIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.DefinitionTable,
|
||||||
|
Columns: []string{userattributevalue.DefinitionColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
|
}
|
||||||
|
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||||
|
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||||
|
err = &NotFoundError{userattributevalue.Label}
|
||||||
|
} else if sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_u.mutation.done = true
|
||||||
|
return _node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueUpdateOne is the builder for updating a single UserAttributeValue entity.
|
||||||
|
type UserAttributeValueUpdateOne struct {
|
||||||
|
config
|
||||||
|
fields []string
|
||||||
|
hooks []Hook
|
||||||
|
mutation *UserAttributeValueMutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetUpdatedAt(v time.Time) *UserAttributeValueUpdateOne {
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserID sets the "user_id" field.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetUserID(v int64) *UserAttributeValueUpdateOne {
|
||||||
|
_u.mutation.SetUserID(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUserID sets the "user_id" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetNillableUserID(v *int64) *UserAttributeValueUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetUserID(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAttributeID sets the "attribute_id" field.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetAttributeID(v int64) *UserAttributeValueUpdateOne {
|
||||||
|
_u.mutation.SetAttributeID(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableAttributeID sets the "attribute_id" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetNillableAttributeID(v *int64) *UserAttributeValueUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetAttributeID(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetValue sets the "value" field.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetValue(v string) *UserAttributeValueUpdateOne {
|
||||||
|
_u.mutation.SetValue(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableValue sets the "value" field if the given value is not nil.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetNillableValue(v *string) *UserAttributeValueUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetValue(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUser sets the "user" edge to the User entity.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetUser(v *User) *UserAttributeValueUpdateOne {
|
||||||
|
return _u.SetUserID(v.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefinitionID sets the "definition" edge to the UserAttributeDefinition entity by ID.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetDefinitionID(id int64) *UserAttributeValueUpdateOne {
|
||||||
|
_u.mutation.SetDefinitionID(id)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefinition sets the "definition" edge to the UserAttributeDefinition entity.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SetDefinition(v *UserAttributeDefinition) *UserAttributeValueUpdateOne {
|
||||||
|
return _u.SetDefinitionID(v.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutation returns the UserAttributeValueMutation object of the builder.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) Mutation() *UserAttributeValueMutation {
|
||||||
|
return _u.mutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUser clears the "user" edge to the User entity.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) ClearUser() *UserAttributeValueUpdateOne {
|
||||||
|
_u.mutation.ClearUser()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearDefinition clears the "definition" edge to the UserAttributeDefinition entity.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) ClearDefinition() *UserAttributeValueUpdateOne {
|
||||||
|
_u.mutation.ClearDefinition()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the UserAttributeValueUpdate builder.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueUpdateOne {
|
||||||
|
_u.mutation.Where(ps...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||||
|
// The default is selecting all fields defined in the entity schema.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) Select(field string, fields ...string) *UserAttributeValueUpdateOne {
|
||||||
|
_u.fields = append([]string{field}, fields...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save executes the query and returns the updated UserAttributeValue entity.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) Save(ctx context.Context) (*UserAttributeValue, error) {
|
||||||
|
_u.defaults()
|
||||||
|
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveX is like Save, but panics if an error occurs.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) SaveX(ctx context.Context) *UserAttributeValue {
|
||||||
|
node, err := _u.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query on the entity.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) Exec(ctx context.Context) error {
|
||||||
|
_, err := _u.Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) ExecX(ctx context.Context) {
|
||||||
|
if err := _u.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaults sets the default values of the builder before save.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) defaults() {
|
||||||
|
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||||
|
v := userattributevalue.UpdateDefaultUpdatedAt()
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check runs all checks and user-defined validators on the builder.
|
||||||
|
func (_u *UserAttributeValueUpdateOne) check() error {
|
||||||
|
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||||
|
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.user"`)
|
||||||
|
}
|
||||||
|
if _u.mutation.DefinitionCleared() && len(_u.mutation.DefinitionIDs()) > 0 {
|
||||||
|
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.definition"`)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_u *UserAttributeValueUpdateOne) sqlSave(ctx context.Context) (_node *UserAttributeValue, err error) {
|
||||||
|
if err := _u.check(); err != nil {
|
||||||
|
return _node, err
|
||||||
|
}
|
||||||
|
_spec := sqlgraph.NewUpdateSpec(userattributevalue.Table, userattributevalue.Columns, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
|
||||||
|
id, ok := _u.mutation.ID()
|
||||||
|
if !ok {
|
||||||
|
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserAttributeValue.id" for update`)}
|
||||||
|
}
|
||||||
|
_spec.Node.ID.Value = id
|
||||||
|
if fields := _u.fields; len(fields) > 0 {
|
||||||
|
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, userattributevalue.FieldID)
|
||||||
|
for _, f := range fields {
|
||||||
|
if !userattributevalue.ValidColumn(f) {
|
||||||
|
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||||
|
}
|
||||||
|
if f != userattributevalue.FieldID {
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||||
|
_spec.SetField(userattributevalue.FieldUpdatedAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Value(); ok {
|
||||||
|
_spec.SetField(userattributevalue.FieldValue, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.UserCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.UserTable,
|
||||||
|
Columns: []string{userattributevalue.UserColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.UserTable,
|
||||||
|
Columns: []string{userattributevalue.UserColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
|
}
|
||||||
|
if _u.mutation.DefinitionCleared() {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.DefinitionTable,
|
||||||
|
Columns: []string{userattributevalue.DefinitionColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||||
|
}
|
||||||
|
if nodes := _u.mutation.DefinitionIDs(); len(nodes) > 0 {
|
||||||
|
edge := &sqlgraph.EdgeSpec{
|
||||||
|
Rel: sqlgraph.M2O,
|
||||||
|
Inverse: true,
|
||||||
|
Table: userattributevalue.DefinitionTable,
|
||||||
|
Columns: []string{userattributevalue.DefinitionColumn},
|
||||||
|
Bidi: false,
|
||||||
|
Target: &sqlgraph.EdgeTarget{
|
||||||
|
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, k := range nodes {
|
||||||
|
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||||
|
}
|
||||||
|
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||||
|
}
|
||||||
|
_node = &UserAttributeValue{config: _u.config}
|
||||||
|
_spec.Assign = _node.assignValues
|
||||||
|
_spec.ScanValues = _node.scanValues
|
||||||
|
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
|
||||||
|
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||||
|
err = &NotFoundError{userattributevalue.Label}
|
||||||
|
} else if sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_u.mutation.done = true
|
||||||
|
return _node, nil
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package config
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
@@ -43,6 +44,7 @@ type Config struct {
|
|||||||
|
|
||||||
type GeminiConfig struct {
|
type GeminiConfig struct {
|
||||||
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
|
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
|
||||||
|
Quota GeminiQuotaConfig `mapstructure:"quota"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiOAuthConfig struct {
|
type GeminiOAuthConfig struct {
|
||||||
@@ -51,6 +53,17 @@ type GeminiOAuthConfig struct {
|
|||||||
Scopes string `mapstructure:"scopes"`
|
Scopes string `mapstructure:"scopes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GeminiQuotaConfig struct {
|
||||||
|
Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"`
|
||||||
|
Policy string `mapstructure:"policy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiTierQuotaConfig struct {
|
||||||
|
ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"`
|
||||||
|
FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"`
|
||||||
|
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
// TokenRefreshConfig OAuth token自动刷新配置
|
// TokenRefreshConfig OAuth token自动刷新配置
|
||||||
type TokenRefreshConfig struct {
|
type TokenRefreshConfig struct {
|
||||||
// 是否启用自动刷新
|
// 是否启用自动刷新
|
||||||
@@ -119,6 +132,37 @@ type GatewayConfig struct {
|
|||||||
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
|
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
|
||||||
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
|
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
|
||||||
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
|
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
|
||||||
|
|
||||||
|
// 是否记录上游错误响应体摘要(避免输出请求内容)
|
||||||
|
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
|
||||||
|
// 上游错误响应体记录最大字节数(超过会截断)
|
||||||
|
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
|
||||||
|
|
||||||
|
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
|
||||||
|
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
|
||||||
|
|
||||||
|
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||||
|
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||||
|
|
||||||
|
// Scheduling: 账号调度相关配置
|
||||||
|
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GatewaySchedulingConfig accounts scheduling configuration.
|
||||||
|
type GatewaySchedulingConfig struct {
|
||||||
|
// 粘性会话排队配置
|
||||||
|
StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"`
|
||||||
|
StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"`
|
||||||
|
|
||||||
|
// 兜底排队配置
|
||||||
|
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
|
||||||
|
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
|
||||||
|
|
||||||
|
// 负载计算
|
||||||
|
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
||||||
|
|
||||||
|
// 过期槽位清理周期(0 表示禁用)
|
||||||
|
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerConfig) Address() string {
|
func (s *ServerConfig) Address() string {
|
||||||
@@ -313,6 +357,10 @@ func setDefaults() {
|
|||||||
|
|
||||||
// Gateway
|
// Gateway
|
||||||
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||||
|
viper.SetDefault("gateway.log_upstream_error_body", false)
|
||||||
|
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
|
||||||
|
viper.SetDefault("gateway.inject_beta_for_apikey", false)
|
||||||
|
viper.SetDefault("gateway.failover_on_400", false)
|
||||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||||
@@ -323,6 +371,12 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.max_upstream_clients", 5000)
|
viper.SetDefault("gateway.max_upstream_clients", 5000)
|
||||||
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
|
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
|
||||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
|
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
|
||||||
|
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||||
|
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
|
||||||
|
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||||
|
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
||||||
|
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
||||||
|
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
||||||
|
|
||||||
// TokenRefresh
|
// TokenRefresh
|
||||||
viper.SetDefault("token_refresh.enabled", true)
|
viper.SetDefault("token_refresh.enabled", true)
|
||||||
@@ -337,6 +391,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gemini.oauth.client_id", "")
|
viper.SetDefault("gemini.oauth.client_id", "")
|
||||||
viper.SetDefault("gemini.oauth.client_secret", "")
|
viper.SetDefault("gemini.oauth.client_secret", "")
|
||||||
viper.SetDefault("gemini.oauth.scopes", "")
|
viper.SetDefault("gemini.oauth.scopes", "")
|
||||||
|
viper.SetDefault("gemini.quota.policy", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
@@ -411,6 +466,21 @@ func (c *Config) Validate() error {
|
|||||||
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
|
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
|
||||||
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
|
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
|
||||||
}
|
}
|
||||||
|
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
func TestNormalizeRunMode(t *testing.T) {
|
func TestNormalizeRunMode(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
|
||||||
|
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
|
||||||
|
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
|
||||||
|
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 {
|
||||||
|
t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting)
|
||||||
|
}
|
||||||
|
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
|
||||||
|
t.Fatalf("LoadBatchEnabled = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
|
||||||
|
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 {
|
||||||
|
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OAuthHandler handles OAuth-related operations for accounts
|
// OAuthHandler handles OAuth-related operations for accounts
|
||||||
@@ -989,3 +991,164 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, models)
|
response.Success(c, models)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshTier handles refreshing Google One tier for a single account
|
||||||
|
// POST /api/v1/admin/accounts/:id/refresh-tier
|
||||||
|
func (h *AccountHandler) RefreshTier(c *gin.Context) {
|
||||||
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid account ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "Account not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.Platform != service.PlatformGemini || account.Type != service.AccountTypeOAuth {
|
||||||
|
response.BadRequest(c, "Only Gemini OAuth accounts support tier refresh")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthType, _ := account.Credentials["oauth_type"].(string)
|
||||||
|
if oauthType != "google_one" {
|
||||||
|
response.BadRequest(c, "Only google_one OAuth accounts support tier refresh")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tierID, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, updateErr := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
|
||||||
|
Credentials: creds,
|
||||||
|
Extra: extra,
|
||||||
|
})
|
||||||
|
if updateErr != nil {
|
||||||
|
response.ErrorFrom(c, updateErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"tier_id": tierID,
|
||||||
|
"storage_info": extra,
|
||||||
|
"drive_storage_limit": extra["drive_storage_limit"],
|
||||||
|
"drive_storage_usage": extra["drive_storage_usage"],
|
||||||
|
"updated_at": extra["drive_tier_updated_at"],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchRefreshTierRequest represents batch tier refresh request
|
||||||
|
type BatchRefreshTierRequest struct {
|
||||||
|
AccountIDs []int64 `json:"account_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchRefreshTier handles batch refreshing Google One tier
|
||||||
|
// POST /api/v1/admin/accounts/batch-refresh-tier
|
||||||
|
func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||||
|
var req BatchRefreshTierRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
req = BatchRefreshTierRequest{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
accounts := make([]*service.Account, 0)
|
||||||
|
|
||||||
|
if len(req.AccountIDs) == 0 {
|
||||||
|
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range allAccounts {
|
||||||
|
acc := &allAccounts[i]
|
||||||
|
oauthType, _ := acc.Credentials["oauth_type"].(string)
|
||||||
|
if oauthType == "google_one" {
|
||||||
|
accounts = append(accounts, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fetched, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, acc := range fetched {
|
||||||
|
if acc == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if acc.Platform != service.PlatformGemini || acc.Type != service.AccountTypeOAuth {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
oauthType, _ := acc.Credentials["oauth_type"].(string)
|
||||||
|
if oauthType != "google_one" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
accounts = append(accounts, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxConcurrency = 10
|
||||||
|
g, gctx := errgroup.WithContext(ctx)
|
||||||
|
g.SetLimit(maxConcurrency)
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var successCount, failedCount int
|
||||||
|
var errors []gin.H
|
||||||
|
|
||||||
|
for _, account := range accounts {
|
||||||
|
acc := account // 闭包捕获
|
||||||
|
g.Go(func() error {
|
||||||
|
_, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(gctx, acc)
|
||||||
|
if err != nil {
|
||||||
|
mu.Lock()
|
||||||
|
failedCount++
|
||||||
|
errors = append(errors, gin.H{
|
||||||
|
"account_id": acc.ID,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, updateErr := h.adminService.UpdateAccount(gctx, acc.ID, &service.UpdateAccountInput{
|
||||||
|
Credentials: creds,
|
||||||
|
Extra: extra,
|
||||||
|
})
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
if updateErr != nil {
|
||||||
|
failedCount++
|
||||||
|
errors = append(errors, gin.H{
|
||||||
|
"account_id": acc.ID,
|
||||||
|
"error": updateErr.Error(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
successCount++
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Wait(); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
results := gin.H{
|
||||||
|
"total": len(accounts),
|
||||||
|
"success": successCount,
|
||||||
|
"failed": failedCount,
|
||||||
|
"errors": errors,
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, results)
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,8 +46,8 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
|||||||
if oauthType == "" {
|
if oauthType == "" {
|
||||||
oauthType = "code_assist"
|
oauthType = "code_assist"
|
||||||
}
|
}
|
||||||
if oauthType != "code_assist" && oauthType != "ai_studio" {
|
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
|
||||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
|
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,8 +92,8 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
|
|||||||
if oauthType == "" {
|
if oauthType == "" {
|
||||||
oauthType = "code_assist"
|
oauthType = "code_assist"
|
||||||
}
|
}
|
||||||
if oauthType != "code_assist" && oauthType != "ai_studio" {
|
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
|
||||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
|
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Limit to 30 results
|
// Limit to 30 results
|
||||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
|
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
342
backend/internal/handler/admin/user_attribute_handler.go
Normal file
342
backend/internal/handler/admin/user_attribute_handler.go
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeHandler handles user attribute management
|
||||||
|
type UserAttributeHandler struct {
|
||||||
|
attrService *service.UserAttributeService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserAttributeHandler creates a new handler
|
||||||
|
func NewUserAttributeHandler(attrService *service.UserAttributeService) *UserAttributeHandler {
|
||||||
|
return &UserAttributeHandler{attrService: attrService}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Request/Response DTOs ---
|
||||||
|
|
||||||
|
// CreateAttributeDefinitionRequest represents create attribute definition request
|
||||||
|
type CreateAttributeDefinitionRequest struct {
|
||||||
|
Key string `json:"key" binding:"required,min=1,max=100"`
|
||||||
|
Name string `json:"name" binding:"required,min=1,max=255"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Type string `json:"type" binding:"required"`
|
||||||
|
Options []service.UserAttributeOption `json:"options"`
|
||||||
|
Required bool `json:"required"`
|
||||||
|
Validation service.UserAttributeValidation `json:"validation"`
|
||||||
|
Placeholder string `json:"placeholder"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAttributeDefinitionRequest represents update attribute definition request
|
||||||
|
type UpdateAttributeDefinitionRequest struct {
|
||||||
|
Name *string `json:"name"`
|
||||||
|
Description *string `json:"description"`
|
||||||
|
Type *string `json:"type"`
|
||||||
|
Options *[]service.UserAttributeOption `json:"options"`
|
||||||
|
Required *bool `json:"required"`
|
||||||
|
Validation *service.UserAttributeValidation `json:"validation"`
|
||||||
|
Placeholder *string `json:"placeholder"`
|
||||||
|
Enabled *bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReorderRequest represents reorder attribute definitions request
|
||||||
|
type ReorderRequest struct {
|
||||||
|
IDs []int64 `json:"ids" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAttributesRequest represents update user attributes request
|
||||||
|
type UpdateUserAttributesRequest struct {
|
||||||
|
Values map[int64]string `json:"values" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchGetUserAttributesRequest represents batch get user attributes request
|
||||||
|
type BatchGetUserAttributesRequest struct {
|
||||||
|
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchUserAttributesResponse represents batch user attributes response
|
||||||
|
type BatchUserAttributesResponse struct {
|
||||||
|
// Map of userID -> map of attributeID -> value
|
||||||
|
Attributes map[int64]map[int64]string `json:"attributes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttributeDefinitionResponse represents attribute definition response
|
||||||
|
type AttributeDefinitionResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Key string `json:"key"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Options []service.UserAttributeOption `json:"options"`
|
||||||
|
Required bool `json:"required"`
|
||||||
|
Validation service.UserAttributeValidation `json:"validation"`
|
||||||
|
Placeholder string `json:"placeholder"`
|
||||||
|
DisplayOrder int `json:"display_order"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttributeValueResponse represents attribute value response
|
||||||
|
type AttributeValueResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
AttributeID int64 `json:"attribute_id"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helpers ---
|
||||||
|
|
||||||
|
func defToResponse(def *service.UserAttributeDefinition) *AttributeDefinitionResponse {
|
||||||
|
return &AttributeDefinitionResponse{
|
||||||
|
ID: def.ID,
|
||||||
|
Key: def.Key,
|
||||||
|
Name: def.Name,
|
||||||
|
Description: def.Description,
|
||||||
|
Type: string(def.Type),
|
||||||
|
Options: def.Options,
|
||||||
|
Required: def.Required,
|
||||||
|
Validation: def.Validation,
|
||||||
|
Placeholder: def.Placeholder,
|
||||||
|
DisplayOrder: def.DisplayOrder,
|
||||||
|
Enabled: def.Enabled,
|
||||||
|
CreatedAt: def.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||||
|
UpdatedAt: def.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func valueToResponse(val *service.UserAttributeValue) *AttributeValueResponse {
|
||||||
|
return &AttributeValueResponse{
|
||||||
|
ID: val.ID,
|
||||||
|
UserID: val.UserID,
|
||||||
|
AttributeID: val.AttributeID,
|
||||||
|
Value: val.Value,
|
||||||
|
CreatedAt: val.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||||
|
UpdatedAt: val.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Handlers ---
|
||||||
|
|
||||||
|
// ListDefinitions lists all attribute definitions
|
||||||
|
// GET /admin/user-attributes
|
||||||
|
func (h *UserAttributeHandler) ListDefinitions(c *gin.Context) {
|
||||||
|
enabledOnly := c.Query("enabled") == "true"
|
||||||
|
|
||||||
|
defs, err := h.attrService.ListDefinitions(c.Request.Context(), enabledOnly)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*AttributeDefinitionResponse, 0, len(defs))
|
||||||
|
for i := range defs {
|
||||||
|
out = append(out, defToResponse(&defs[i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateDefinition creates a new attribute definition
|
||||||
|
// POST /admin/user-attributes
|
||||||
|
func (h *UserAttributeHandler) CreateDefinition(c *gin.Context) {
|
||||||
|
var req CreateAttributeDefinitionRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
def, err := h.attrService.CreateDefinition(c.Request.Context(), service.CreateAttributeDefinitionInput{
|
||||||
|
Key: req.Key,
|
||||||
|
Name: req.Name,
|
||||||
|
Description: req.Description,
|
||||||
|
Type: service.UserAttributeType(req.Type),
|
||||||
|
Options: req.Options,
|
||||||
|
Required: req.Required,
|
||||||
|
Validation: req.Validation,
|
||||||
|
Placeholder: req.Placeholder,
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, defToResponse(def))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDefinition updates an attribute definition
|
||||||
|
// PUT /admin/user-attributes/:id
|
||||||
|
func (h *UserAttributeHandler) UpdateDefinition(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid attribute ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req UpdateAttributeDefinitionRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
input := service.UpdateAttributeDefinitionInput{
|
||||||
|
Name: req.Name,
|
||||||
|
Description: req.Description,
|
||||||
|
Options: req.Options,
|
||||||
|
Required: req.Required,
|
||||||
|
Validation: req.Validation,
|
||||||
|
Placeholder: req.Placeholder,
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
}
|
||||||
|
if req.Type != nil {
|
||||||
|
t := service.UserAttributeType(*req.Type)
|
||||||
|
input.Type = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
def, err := h.attrService.UpdateDefinition(c.Request.Context(), id, input)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, defToResponse(def))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDefinition deletes an attribute definition
|
||||||
|
// DELETE /admin/user-attributes/:id
|
||||||
|
func (h *UserAttributeHandler) DeleteDefinition(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid attribute ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.attrService.DeleteDefinition(c.Request.Context(), id); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Attribute definition deleted successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReorderDefinitions reorders attribute definitions
|
||||||
|
// PUT /admin/user-attributes/reorder
|
||||||
|
func (h *UserAttributeHandler) ReorderDefinitions(c *gin.Context) {
|
||||||
|
var req ReorderRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert IDs array to orders map (position in array = display_order)
|
||||||
|
orders := make(map[int64]int, len(req.IDs))
|
||||||
|
for i, id := range req.IDs {
|
||||||
|
orders[id] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.attrService.ReorderDefinitions(c.Request.Context(), orders); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Reorder successful"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserAttributes gets a user's attribute values
|
||||||
|
// GET /admin/users/:id/attributes
|
||||||
|
func (h *UserAttributeHandler) GetUserAttributes(c *gin.Context) {
|
||||||
|
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*AttributeValueResponse, 0, len(values))
|
||||||
|
for i := range values {
|
||||||
|
out = append(out, valueToResponse(&values[i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAttributes updates a user's attribute values
|
||||||
|
// PUT /admin/users/:id/attributes
|
||||||
|
func (h *UserAttributeHandler) UpdateUserAttributes(c *gin.Context) {
|
||||||
|
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req UpdateUserAttributesRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs := make([]service.UpdateUserAttributeInput, 0, len(req.Values))
|
||||||
|
for attrID, value := range req.Values {
|
||||||
|
inputs = append(inputs, service.UpdateUserAttributeInput{
|
||||||
|
AttributeID: attrID,
|
||||||
|
Value: value,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.attrService.UpdateUserAttributes(c.Request.Context(), userID, inputs); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return updated values
|
||||||
|
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*AttributeValueResponse, 0, len(values))
|
||||||
|
for i := range values {
|
||||||
|
out = append(out, valueToResponse(&values[i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBatchUserAttributes gets attribute values for multiple users
|
||||||
|
// POST /admin/user-attributes/batch
|
||||||
|
func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
|
||||||
|
var req BatchGetUserAttributesRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.UserIDs) == 0 {
|
||||||
|
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
|
||||||
|
}
|
||||||
@@ -27,7 +27,6 @@ type CreateUserRequest struct {
|
|||||||
Email string `json:"email" binding:"required,email"`
|
Email string `json:"email" binding:"required,email"`
|
||||||
Password string `json:"password" binding:"required,min=6"`
|
Password string `json:"password" binding:"required,min=6"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Wechat string `json:"wechat"`
|
|
||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
Balance float64 `json:"balance"`
|
Balance float64 `json:"balance"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
@@ -40,7 +39,6 @@ type UpdateUserRequest struct {
|
|||||||
Email string `json:"email" binding:"omitempty,email"`
|
Email string `json:"email" binding:"omitempty,email"`
|
||||||
Password string `json:"password" binding:"omitempty,min=6"`
|
Password string `json:"password" binding:"omitempty,min=6"`
|
||||||
Username *string `json:"username"`
|
Username *string `json:"username"`
|
||||||
Wechat *string `json:"wechat"`
|
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Balance *float64 `json:"balance"`
|
Balance *float64 `json:"balance"`
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
@@ -57,13 +55,22 @@ type UpdateBalanceRequest struct {
|
|||||||
|
|
||||||
// List handles listing all users with pagination
|
// List handles listing all users with pagination
|
||||||
// GET /api/v1/admin/users
|
// GET /api/v1/admin/users
|
||||||
|
// Query params:
|
||||||
|
// - status: filter by user status
|
||||||
|
// - role: filter by user role
|
||||||
|
// - search: search in email, username
|
||||||
|
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
|
||||||
func (h *UserHandler) List(c *gin.Context) {
|
func (h *UserHandler) List(c *gin.Context) {
|
||||||
page, pageSize := response.ParsePagination(c)
|
page, pageSize := response.ParsePagination(c)
|
||||||
status := c.Query("status")
|
|
||||||
role := c.Query("role")
|
|
||||||
search := c.Query("search")
|
|
||||||
|
|
||||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
|
filters := service.UserListFilters{
|
||||||
|
Status: c.Query("status"),
|
||||||
|
Role: c.Query("role"),
|
||||||
|
Search: c.Query("search"),
|
||||||
|
Attributes: parseAttributeFilters(c),
|
||||||
|
}
|
||||||
|
|
||||||
|
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -76,6 +83,29 @@ func (h *UserHandler) List(c *gin.Context) {
|
|||||||
response.Paginated(c, out, total, page, pageSize)
|
response.Paginated(c, out, total, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseAttributeFilters extracts attribute filters from query params
|
||||||
|
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
|
||||||
|
func parseAttributeFilters(c *gin.Context) map[int64]string {
|
||||||
|
result := make(map[int64]string)
|
||||||
|
|
||||||
|
// Get all query params and look for attr[*] pattern
|
||||||
|
for key, values := range c.Request.URL.Query() {
|
||||||
|
if len(values) == 0 || values[0] == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Check if key matches pattern attr[{id}]
|
||||||
|
if len(key) > 5 && key[:5] == "attr[" && key[len(key)-1] == ']' {
|
||||||
|
idStr := key[5 : len(key)-1]
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err == nil && id > 0 {
|
||||||
|
result[id] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// GetByID handles getting a user by ID
|
// GetByID handles getting a user by ID
|
||||||
// GET /api/v1/admin/users/:id
|
// GET /api/v1/admin/users/:id
|
||||||
func (h *UserHandler) GetByID(c *gin.Context) {
|
func (h *UserHandler) GetByID(c *gin.Context) {
|
||||||
@@ -107,7 +137,6 @@ func (h *UserHandler) Create(c *gin.Context) {
|
|||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Password: req.Password,
|
Password: req.Password,
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Wechat: req.Wechat,
|
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
@@ -141,7 +170,6 @@ func (h *UserHandler) Update(c *gin.Context) {
|
|||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Password: req.Password,
|
Password: req.Password,
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Wechat: req.Wechat,
|
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ func UserFromServiceShallow(u *service.User) *User {
|
|||||||
ID: u.ID,
|
ID: u.ID,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
Username: u.Username,
|
Username: u.Username,
|
||||||
Wechat: u.Wechat,
|
|
||||||
Notes: u.Notes,
|
Notes: u.Notes,
|
||||||
Role: u.Role,
|
Role: u.Role,
|
||||||
Balance: u.Balance,
|
Balance: u.Balance,
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ type User struct {
|
|||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Wechat string `json:"wechat"`
|
|
||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Balance float64 `json:"balance"`
|
Balance float64 `json:"balance"`
|
||||||
|
|||||||
@@ -142,6 +142,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
} else if apiKey.Group != nil {
|
} else if apiKey.Group != nil {
|
||||||
platform = apiKey.Group.Platform
|
platform = apiKey.Group.Platform
|
||||||
}
|
}
|
||||||
|
sessionKey := sessionHash
|
||||||
|
if platform == service.PlatformGemini && sessionHash != "" {
|
||||||
|
sessionKey = "gemini:" + sessionHash
|
||||||
|
}
|
||||||
|
|
||||||
if platform == service.PlatformGemini {
|
if platform == service.PlatformGemini {
|
||||||
const maxAccountSwitches = 3
|
const maxAccountSwitches = 3
|
||||||
@@ -150,7 +154,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
|
|
||||||
for {
|
for {
|
||||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(failedAccountIDs) == 0 {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
@@ -159,9 +163,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
account := selection.Account
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||||
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
if reqStream {
|
if reqStream {
|
||||||
sendMockWarmupStream(c, reqModel)
|
sendMockWarmupStream(c, reqModel)
|
||||||
} else {
|
} else {
|
||||||
@@ -171,11 +179,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. 获取账号并发槽位
|
// 3. 获取账号并发槽位
|
||||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
if err != nil {
|
var accountWaitRelease func()
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
if !selection.Acquired {
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
if selection.WaitPlan == nil {
|
||||||
return
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
|
} else if !canWait {
|
||||||
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// Only set release function if increment succeeded
|
||||||
|
accountWaitRelease = func() {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
|
c,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
selection.WaitPlan.Timeout,
|
||||||
|
reqStream,
|
||||||
|
&streamStarted,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if accountWaitRelease != nil {
|
||||||
|
accountWaitRelease()
|
||||||
|
}
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||||
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
@@ -188,6 +231,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
|
if accountWaitRelease != nil {
|
||||||
|
accountWaitRelease()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
@@ -232,7 +278,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
// 选择支持该模型的账号
|
// 选择支持该模型的账号
|
||||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(failedAccountIDs) == 0 {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
@@ -241,9 +287,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
account := selection.Account
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||||
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
if reqStream {
|
if reqStream {
|
||||||
sendMockWarmupStream(c, reqModel)
|
sendMockWarmupStream(c, reqModel)
|
||||||
} else {
|
} else {
|
||||||
@@ -253,11 +303,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. 获取账号并发槽位
|
// 3. 获取账号并发槽位
|
||||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
if err != nil {
|
var accountWaitRelease func()
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
if !selection.Acquired {
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
if selection.WaitPlan == nil {
|
||||||
return
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
|
} else if !canWait {
|
||||||
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// Only set release function if increment succeeded
|
||||||
|
accountWaitRelease = func() {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
|
c,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
selection.WaitPlan.Timeout,
|
||||||
|
reqStream,
|
||||||
|
&streamStarted,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if accountWaitRelease != nil {
|
||||||
|
accountWaitRelease()
|
||||||
|
}
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||||
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
@@ -270,6 +355,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
|
if accountWaitRelease != nil {
|
||||||
|
accountWaitRelease()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
@@ -309,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
// Models handles listing available models
|
// Models handles listing available models
|
||||||
// GET /v1/models
|
// GET /v1/models
|
||||||
// Returns different model lists based on the API key's group platform or forced platform
|
// Returns models based on account configurations (model_mapping whitelist)
|
||||||
|
// Falls back to default models if no whitelist is configured
|
||||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||||
apiKey, _ := middleware2.GetApiKeyFromContext(c)
|
apiKey, _ := middleware2.GetApiKeyFromContext(c)
|
||||||
|
|
||||||
@@ -324,8 +413,37 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return OpenAI models for OpenAI platform groups
|
var groupID *int64
|
||||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
|
var platform string
|
||||||
|
|
||||||
|
if apiKey != nil && apiKey.Group != nil {
|
||||||
|
groupID = &apiKey.Group.ID
|
||||||
|
platform = apiKey.Group.Platform
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get available models from account configurations (without platform filter)
|
||||||
|
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
|
||||||
|
|
||||||
|
if len(availableModels) > 0 {
|
||||||
|
// Build model list from whitelist
|
||||||
|
models := make([]claude.Model, 0, len(availableModels))
|
||||||
|
for _, modelID := range availableModels {
|
||||||
|
models = append(models, claude.Model{
|
||||||
|
ID: modelID,
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: modelID,
|
||||||
|
CreatedAt: "2024-01-01T00:00:00Z",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"object": "list",
|
||||||
|
"data": models,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to default models
|
||||||
|
if platform == "openai" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": openai.DefaultModels,
|
"data": openai.DefaultModels,
|
||||||
@@ -333,7 +451,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default: Claude models
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": claude.DefaultModels,
|
"data": claude.DefaultModels,
|
||||||
|
|||||||
@@ -83,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
|
|||||||
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IncrementAccountWaitCount increments the wait count for an account
|
||||||
|
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
|
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecrementAccountWaitCount decrements the wait count for an account
|
||||||
|
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||||
|
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||||
// For streaming requests, sends ping events during the wait.
|
// For streaming requests, sends ping events during the wait.
|
||||||
// streamStarted is updated if streaming response has begun.
|
// streamStarted is updated if streaming response has begun.
|
||||||
@@ -126,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
|
|||||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
|
||||||
|
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Determine if ping is needed (streaming + ping format defined)
|
// Determine if ping is needed (streaming + ping format defined)
|
||||||
@@ -200,6 +215,11 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
|
||||||
|
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||||
|
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
|
||||||
|
}
|
||||||
|
|
||||||
// nextBackoff 计算下一次退避时间
|
// nextBackoff 计算下一次退避时间
|
||||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||||
// current: 当前退避时间
|
// current: 当前退避时间
|
||||||
|
|||||||
@@ -198,13 +198,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
// 3) select account (sticky session based on request body)
|
// 3) select account (sticky session based on request body)
|
||||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||||
|
sessionKey := sessionHash
|
||||||
|
if sessionHash != "" {
|
||||||
|
sessionKey = "gemini:" + sessionHash
|
||||||
|
}
|
||||||
const maxAccountSwitches = 3
|
const maxAccountSwitches = 3
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
|
|
||||||
for {
|
for {
|
||||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(failedAccountIDs) == 0 {
|
||||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||||
@@ -213,12 +217,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
account := selection.Account
|
||||||
|
|
||||||
// 4) account concurrency slot
|
// 4) account concurrency slot
|
||||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
if err != nil {
|
var accountWaitRelease func()
|
||||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
if !selection.Acquired {
|
||||||
return
|
if selection.WaitPlan == nil {
|
||||||
|
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
|
} else if !canWait {
|
||||||
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
|
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// Only set release function if increment succeeded
|
||||||
|
accountWaitRelease = func() {
|
||||||
|
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
|
||||||
|
c,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
selection.WaitPlan.Timeout,
|
||||||
|
stream,
|
||||||
|
&streamStarted,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if accountWaitRelease != nil {
|
||||||
|
accountWaitRelease()
|
||||||
|
}
|
||||||
|
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||||
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5) forward (根据平台分流)
|
// 5) forward (根据平台分流)
|
||||||
@@ -231,6 +271,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
|
if accountWaitRelease != nil {
|
||||||
|
accountWaitRelease()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type AdminHandlers struct {
|
|||||||
System *admin.SystemHandler
|
System *admin.SystemHandler
|
||||||
Subscription *admin.SubscriptionHandler
|
Subscription *admin.SubscriptionHandler
|
||||||
Usage *admin.UsageHandler
|
Usage *admin.UsageHandler
|
||||||
|
UserAttribute *admin.UserAttributeHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handlers contains all HTTP handlers
|
// Handlers contains all HTTP handlers
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
// Select account supporting the requested model
|
// Select account supporting the requested model
|
||||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(failedAccountIDs) == 0 {
|
||||||
@@ -156,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
account := selection.Account
|
||||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||||
|
|
||||||
// 3. Acquire account concurrency slot
|
// 3. Acquire account concurrency slot
|
||||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
if err != nil {
|
var accountWaitRelease func()
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
if !selection.Acquired {
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
if selection.WaitPlan == nil {
|
||||||
return
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
|
} else if !canWait {
|
||||||
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// Only set release function if increment succeeded
|
||||||
|
accountWaitRelease = func() {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
|
c,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
selection.WaitPlan.Timeout,
|
||||||
|
reqStream,
|
||||||
|
&streamStarted,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if accountWaitRelease != nil {
|
||||||
|
accountWaitRelease()
|
||||||
|
}
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
|
||||||
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward request
|
// Forward request
|
||||||
@@ -171,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
|
if accountWaitRelease != nil {
|
||||||
|
accountWaitRelease()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ type ChangePasswordRequest struct {
|
|||||||
// UpdateProfileRequest represents the update profile request payload
|
// UpdateProfileRequest represents the update profile request payload
|
||||||
type UpdateProfileRequest struct {
|
type UpdateProfileRequest struct {
|
||||||
Username *string `json:"username"`
|
Username *string `json:"username"`
|
||||||
Wechat *string `json:"wechat"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProfile handles getting user profile
|
// GetProfile handles getting user profile
|
||||||
@@ -99,7 +98,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
|||||||
|
|
||||||
svcReq := service.UpdateProfileRequest{
|
svcReq := service.UpdateProfileRequest{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Wechat: req.Wechat,
|
|
||||||
}
|
}
|
||||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
|
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ func ProvideAdminHandlers(
|
|||||||
systemHandler *admin.SystemHandler,
|
systemHandler *admin.SystemHandler,
|
||||||
subscriptionHandler *admin.SubscriptionHandler,
|
subscriptionHandler *admin.SubscriptionHandler,
|
||||||
usageHandler *admin.UsageHandler,
|
usageHandler *admin.UsageHandler,
|
||||||
|
userAttributeHandler *admin.UserAttributeHandler,
|
||||||
) *AdminHandlers {
|
) *AdminHandlers {
|
||||||
return &AdminHandlers{
|
return &AdminHandlers{
|
||||||
Dashboard: dashboardHandler,
|
Dashboard: dashboardHandler,
|
||||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
|||||||
System: systemHandler,
|
System: systemHandler,
|
||||||
Subscription: subscriptionHandler,
|
Subscription: subscriptionHandler,
|
||||||
Usage: usageHandler,
|
Usage: usageHandler,
|
||||||
|
UserAttribute: userAttributeHandler,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,6 +109,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
ProvideSystemHandler,
|
ProvideSystemHandler,
|
||||||
admin.NewSubscriptionHandler,
|
admin.NewSubscriptionHandler,
|
||||||
admin.NewUsageHandler,
|
admin.NewUsageHandler,
|
||||||
|
admin.NewUserAttributeHandler,
|
||||||
|
|
||||||
// AdminHandlers and Handlers constructors
|
// AdminHandlers and Handlers constructors
|
||||||
ProvideAdminHandlers,
|
ProvideAdminHandlers,
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ var geminiModels = []string{
|
|||||||
"gemini-2.5-flash-lite",
|
"gemini-2.5-flash-lite",
|
||||||
"gemini-3-flash",
|
"gemini-3-flash",
|
||||||
"gemini-3-pro-low",
|
"gemini-3-pro-low",
|
||||||
|
"gemini-3-pro-high",
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
@@ -641,6 +642,37 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
|
|||||||
t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"])
|
t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestClaudeMessagesWithGeminiModel 测试在 Claude 端点使用 Gemini 模型
|
||||||
|
// 验证:通过 /v1/messages 端点传入 gemini 模型名的场景(含前缀映射)
|
||||||
|
// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity")
|
||||||
|
func TestClaudeMessagesWithGeminiModel(t *testing.T) {
|
||||||
|
if endpointPrefix != "/antigravity" {
|
||||||
|
t.Skip("仅在 Antigravity 模式下运行")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试通过 Claude 端点调用 Gemini 模型
|
||||||
|
geminiViaClaude := []string{
|
||||||
|
"gemini-3-flash", // 直接支持
|
||||||
|
"gemini-3-pro-low", // 直接支持
|
||||||
|
"gemini-3-pro-high", // 直接支持
|
||||||
|
"gemini-3-pro", // 前缀映射 -> gemini-3-pro-high
|
||||||
|
"gemini-3-pro-preview", // 前缀映射 -> gemini-3-pro-high
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, model := range geminiViaClaude {
|
||||||
|
if i > 0 {
|
||||||
|
time.Sleep(testInterval)
|
||||||
|
}
|
||||||
|
t.Run(model+"_通过Claude端点", func(t *testing.T) {
|
||||||
|
testClaudeMessage(t, model, false)
|
||||||
|
})
|
||||||
|
time.Sleep(testInterval)
|
||||||
|
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
|
||||||
|
testClaudeMessage(t, model, true)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
|
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
|
||||||
// 验证:Gemini 模型接受没有 signature 的 thinking block
|
// 验证:Gemini 模型接受没有 signature 的 thinking block
|
||||||
func TestClaudeMessagesWithNoSignature(t *testing.T) {
|
func TestClaudeMessagesWithNoSignature(t *testing.T) {
|
||||||
@@ -738,3 +770,30 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
|
|||||||
}
|
}
|
||||||
t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"])
|
t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGeminiEndpointWithClaudeModel 测试通过 Gemini 端点调用 Claude 模型
|
||||||
|
// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity")
|
||||||
|
func TestGeminiEndpointWithClaudeModel(t *testing.T) {
|
||||||
|
if endpointPrefix != "/antigravity" {
|
||||||
|
t.Skip("仅在 Antigravity 模式下运行")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试通过 Gemini 端点调用 Claude 模型
|
||||||
|
claudeViaGemini := []string{
|
||||||
|
"claude-sonnet-4-5",
|
||||||
|
"claude-opus-4-5-thinking",
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, model := range claudeViaGemini {
|
||||||
|
if i > 0 {
|
||||||
|
time.Sleep(testInterval)
|
||||||
|
}
|
||||||
|
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
|
||||||
|
testGeminiGenerate(t, model, false)
|
||||||
|
})
|
||||||
|
time.Sleep(testInterval)
|
||||||
|
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
|
||||||
|
testGeminiGenerate(t, model, true)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -54,6 +54,9 @@ type CustomToolSpec struct {
|
|||||||
InputSchema map[string]any `json:"input_schema"`
|
InputSchema map[string]any `json:"input_schema"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
|
||||||
|
type ClaudeCustomToolSpec = CustomToolSpec
|
||||||
|
|
||||||
// SystemBlock system prompt 数组形式的元素
|
// SystemBlock system prompt 数组形式的元素
|
||||||
type SystemBlock struct {
|
type SystemBlock struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|||||||
@@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
|
|||||||
// 用于存储 tool_use id -> name 映射
|
// 用于存储 tool_use id -> name 映射
|
||||||
toolIDToName := make(map[string]string)
|
toolIDToName := make(map[string]string)
|
||||||
|
|
||||||
// 检测是否启用 thinking
|
|
||||||
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
|
||||||
|
|
||||||
// 只有 Gemini 模型支持 dummy thought workaround
|
// 只有 Gemini 模型支持 dummy thought workaround
|
||||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||||
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
|
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
|
||||||
|
|
||||||
|
// 检测是否启用 thinking
|
||||||
|
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||||
|
// 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等),
|
||||||
|
// 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。
|
||||||
|
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
|
||||||
|
|
||||||
// 1. 构建 contents
|
// 1. 构建 contents
|
||||||
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
|
|||||||
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
|
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
|
||||||
|
|
||||||
// 3. 构建 generationConfig
|
// 3. 构建 generationConfig
|
||||||
generationConfig := buildGenerationConfig(claudeReq)
|
reqForGen := claudeReq
|
||||||
|
if requestedThinkingEnabled && !allowDummyThought {
|
||||||
|
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
|
||||||
|
// shallow copy to avoid mutating caller's request
|
||||||
|
clone := *claudeReq
|
||||||
|
clone.Thinking = nil
|
||||||
|
reqForGen = &clone
|
||||||
|
}
|
||||||
|
generationConfig := buildGenerationConfig(reqForGen)
|
||||||
|
|
||||||
// 4. 构建 tools
|
// 4. 构建 tools
|
||||||
tools := buildTools(claudeReq.Tools)
|
tools := buildTools(claudeReq.Tools)
|
||||||
@@ -148,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
|
|||||||
if !hasThoughtPart && len(parts) > 0 {
|
if !hasThoughtPart && len(parts) > 0 {
|
||||||
// 在开头添加 dummy thinking block
|
// 在开头添加 dummy thinking block
|
||||||
parts = append([]GeminiPart{{
|
parts = append([]GeminiPart{{
|
||||||
Text: "Thinking...",
|
Text: "Thinking...",
|
||||||
Thought: true,
|
Thought: true,
|
||||||
|
ThoughtSignature: dummyThoughtSignature,
|
||||||
}}, parts...)
|
}}, parts...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
|
|||||||
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
|
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||||
const dummyThoughtSignature = "skip_thought_signature_validator"
|
const dummyThoughtSignature = "skip_thought_signature_validator"
|
||||||
|
|
||||||
|
// isValidThoughtSignature 验证 thought signature 是否有效
|
||||||
|
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
|
||||||
|
func isValidThoughtSignature(signature string) bool {
|
||||||
|
// 空字符串无效
|
||||||
|
if signature == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
|
||||||
|
// 参考 Claude API 文档和实际观察到的有效 signature
|
||||||
|
if len(signature) < 40 {
|
||||||
|
log.Printf("[Debug] Signature too short: len=%d", len(signature))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否是有效的 base64 字符
|
||||||
|
// base64 字符集: A-Z, a-z, 0-9, +, /, =
|
||||||
|
for i, c := range signature {
|
||||||
|
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
|
||||||
|
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
|
||||||
|
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// buildParts 构建消息的 parts
|
// buildParts 构建消息的 parts
|
||||||
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
|
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
|
||||||
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
|
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
|
||||||
@@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "thinking":
|
case "thinking":
|
||||||
part := GeminiPart{
|
if allowDummyThought {
|
||||||
Text: block.Thinking,
|
// Gemini 模型可以使用 dummy signature
|
||||||
Thought: true,
|
parts = append(parts, GeminiPart{
|
||||||
}
|
Text: block.Thinking,
|
||||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
Thought: true,
|
||||||
if block.Signature != "" {
|
ThoughtSignature: dummyThoughtSignature,
|
||||||
part.ThoughtSignature = block.Signature
|
})
|
||||||
} else if !allowDummyThought {
|
|
||||||
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block
|
|
||||||
log.Printf("Warning: skipping thinking block without signature for Claude model")
|
|
||||||
continue
|
continue
|
||||||
} else {
|
|
||||||
// Gemini 模型使用 dummy signature
|
|
||||||
part.ThoughtSignature = dummyThoughtSignature
|
|
||||||
}
|
}
|
||||||
parts = append(parts, part)
|
|
||||||
|
// Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。
|
||||||
|
signature := strings.TrimSpace(block.Signature)
|
||||||
|
if signature == "" || signature == dummyThoughtSignature {
|
||||||
|
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !isValidThoughtSignature(signature) {
|
||||||
|
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
|
||||||
|
}
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
Text: block.Thinking,
|
||||||
|
Thought: true,
|
||||||
|
ThoughtSignature: signature,
|
||||||
|
})
|
||||||
|
|
||||||
case "image":
|
case "image":
|
||||||
if block.Source != nil && block.Source.Type == "base64" {
|
if block.Source != nil && block.Source.Type == "base64" {
|
||||||
@@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
|||||||
ID: block.ID,
|
ID: block.ID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// 保留原有 signature,或对 Gemini 模型使用 dummy signature
|
// 只有 Gemini 模型使用 dummy signature
|
||||||
if block.Signature != "" {
|
// Claude 模型不设置 signature(避免验证问题)
|
||||||
part.ThoughtSignature = block.Signature
|
if allowDummyThought {
|
||||||
} else if allowDummyThought {
|
|
||||||
part.ThoughtSignature = dummyThoughtSignature
|
part.ThoughtSignature = dummyThoughtSignature
|
||||||
}
|
}
|
||||||
parts = append(parts, part)
|
parts = append(parts, part)
|
||||||
@@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
|||||||
|
|
||||||
// 普通工具
|
// 普通工具
|
||||||
var funcDecls []GeminiFunctionDecl
|
var funcDecls []GeminiFunctionDecl
|
||||||
for _, tool := range tools {
|
for i, tool := range tools {
|
||||||
// 跳过无效工具名称
|
// 跳过无效工具名称
|
||||||
if tool.Name == "" {
|
if strings.TrimSpace(tool.Name) == "" {
|
||||||
log.Printf("Warning: skipping tool with empty name")
|
log.Printf("Warning: skipping tool with empty name")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
|||||||
var inputSchema map[string]any
|
var inputSchema map[string]any
|
||||||
|
|
||||||
// 检查是否为 custom 类型工具 (MCP)
|
// 检查是否为 custom 类型工具 (MCP)
|
||||||
if tool.Type == "custom" && tool.Custom != nil {
|
if tool.Type == "custom" {
|
||||||
// Custom 格式: 从 custom 字段获取 description 和 input_schema
|
if tool.Custom == nil || tool.Custom.InputSchema == nil {
|
||||||
|
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
description = tool.Custom.Description
|
description = tool.Custom.Description
|
||||||
inputSchema = tool.Custom.InputSchema
|
inputSchema = tool.Custom.InputSchema
|
||||||
|
|
||||||
|
// 调试日志:记录 custom 工具的 schema
|
||||||
|
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
|
||||||
|
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 标准格式: 从顶层字段获取
|
// 标准格式: 从顶层字段获取
|
||||||
description = tool.Description
|
description = tool.Description
|
||||||
@@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
|||||||
|
|
||||||
// 清理 JSON Schema
|
// 清理 JSON Schema
|
||||||
params := cleanJSONSchema(inputSchema)
|
params := cleanJSONSchema(inputSchema)
|
||||||
|
|
||||||
// 为 nil schema 提供默认值
|
// 为 nil schema 提供默认值
|
||||||
if params == nil {
|
if params == nil {
|
||||||
params = map[string]any{
|
params = map[string]any{
|
||||||
@@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 调试日志:记录清理后的 schema
|
||||||
|
if paramsJSON, err := json.Marshal(params); err == nil {
|
||||||
|
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
|
||||||
|
}
|
||||||
|
|
||||||
funcDecls = append(funcDecls, GeminiFunctionDecl{
|
funcDecls = append(funcDecls, GeminiFunctionDecl{
|
||||||
Name: tool.Name,
|
Name: tool.Name,
|
||||||
Description: description,
|
Description: description,
|
||||||
@@ -479,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// excludedSchemaKeys 不支持的 schema 字段
|
// excludedSchemaKeys 不支持的 schema 字段
|
||||||
|
// 基于 Claude API (Vertex AI) 的实际支持情况
|
||||||
|
// 支持: type, description, enum, properties, required, additionalProperties, items
|
||||||
|
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
|
||||||
var excludedSchemaKeys = map[string]bool{
|
var excludedSchemaKeys = map[string]bool{
|
||||||
"$schema": true,
|
// 元 schema 字段
|
||||||
"$id": true,
|
"$schema": true,
|
||||||
"$ref": true,
|
"$id": true,
|
||||||
"additionalProperties": true,
|
"$ref": true,
|
||||||
"minLength": true,
|
|
||||||
"maxLength": true,
|
// 字符串验证(Gemini 不支持)
|
||||||
"minItems": true,
|
"minLength": true,
|
||||||
"maxItems": true,
|
"maxLength": true,
|
||||||
"uniqueItems": true,
|
"pattern": true,
|
||||||
"minimum": true,
|
|
||||||
"maximum": true,
|
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||||
"exclusiveMinimum": true,
|
"minimum": true,
|
||||||
"exclusiveMaximum": true,
|
"maximum": true,
|
||||||
"pattern": true,
|
"exclusiveMinimum": true,
|
||||||
"format": true,
|
"exclusiveMaximum": true,
|
||||||
"default": true,
|
"multipleOf": true,
|
||||||
"strict": true,
|
|
||||||
"const": true,
|
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||||
"examples": true,
|
"uniqueItems": true,
|
||||||
"deprecated": true,
|
"minItems": true,
|
||||||
"readOnly": true,
|
"maxItems": true,
|
||||||
"writeOnly": true,
|
|
||||||
"contentMediaType": true,
|
// 组合 schema(Gemini 不支持)
|
||||||
"contentEncoding": true,
|
"oneOf": true,
|
||||||
|
"anyOf": true,
|
||||||
|
"allOf": true,
|
||||||
|
"not": true,
|
||||||
|
"if": true,
|
||||||
|
"then": true,
|
||||||
|
"else": true,
|
||||||
|
"$defs": true,
|
||||||
|
"definitions": true,
|
||||||
|
|
||||||
|
// 对象验证(仅保留 properties/required/additionalProperties)
|
||||||
|
"minProperties": true,
|
||||||
|
"maxProperties": true,
|
||||||
|
"patternProperties": true,
|
||||||
|
"propertyNames": true,
|
||||||
|
"dependencies": true,
|
||||||
|
"dependentSchemas": true,
|
||||||
|
"dependentRequired": true,
|
||||||
|
|
||||||
|
// 其他不支持的字段
|
||||||
|
"default": true,
|
||||||
|
"const": true,
|
||||||
|
"examples": true,
|
||||||
|
"deprecated": true,
|
||||||
|
"readOnly": true,
|
||||||
|
"writeOnly": true,
|
||||||
|
"contentMediaType": true,
|
||||||
|
"contentEncoding": true,
|
||||||
|
|
||||||
|
// Claude 特有字段
|
||||||
|
"strict": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanSchemaValue 递归清理 schema 值
|
// cleanSchemaValue 递归清理 schema 值
|
||||||
@@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
|
||||||
|
if k == "format" {
|
||||||
|
if formatStr, ok := val.(string); ok {
|
||||||
|
// Gemini 只支持 date-time, date, time
|
||||||
|
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
|
||||||
|
result[k] = val
|
||||||
|
}
|
||||||
|
// 其他 format 值直接跳过
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
|
||||||
|
if k == "additionalProperties" {
|
||||||
|
if boolVal, ok := val.(bool); ok {
|
||||||
|
result[k] = boolVal
|
||||||
|
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
|
||||||
|
} else {
|
||||||
|
// 如果是 schema 对象,转换为 false(更安全的默认值)
|
||||||
|
result[k] = false
|
||||||
|
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// 递归清理所有值
|
// 递归清理所有值
|
||||||
result[k] = cleanSchemaValue(val)
|
result[k] = cleanSchemaValue(val)
|
||||||
}
|
}
|
||||||
|
|||||||
179
backend/internal/pkg/antigravity/request_transformer_test.go
Normal file
179
backend/internal/pkg/antigravity/request_transformer_test.go
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||||
|
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
allowDummyThought bool
|
||||||
|
expectedParts int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Claude model - skip thinking block without signature",
|
||||||
|
content: `[
|
||||||
|
{"type": "text", "text": "Hello"},
|
||||||
|
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||||
|
{"type": "text", "text": "World"}
|
||||||
|
]`,
|
||||||
|
allowDummyThought: false,
|
||||||
|
expectedParts: 2, // 只有两个text block
|
||||||
|
description: "Claude模型应该跳过无signature的thinking block",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Claude model - keep thinking block with signature",
|
||||||
|
content: `[
|
||||||
|
{"type": "text", "text": "Hello"},
|
||||||
|
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
|
||||||
|
{"type": "text", "text": "World"}
|
||||||
|
]`,
|
||||||
|
allowDummyThought: false,
|
||||||
|
expectedParts: 3, // 三个block都保留
|
||||||
|
description: "Claude模型应该保留有signature的thinking block",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gemini model - use dummy signature",
|
||||||
|
content: `[
|
||||||
|
{"type": "text", "text": "Hello"},
|
||||||
|
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||||
|
{"type": "text", "text": "World"}
|
||||||
|
]`,
|
||||||
|
allowDummyThought: true,
|
||||||
|
expectedParts: 3, // 三个block都保留,thinking使用dummy signature
|
||||||
|
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
toolIDToName := make(map[string]string)
|
||||||
|
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildParts() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) != tt.expectedParts {
|
||||||
|
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
|
||||||
|
func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tools []ClaudeTool
|
||||||
|
expectedLen int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Standard tool format",
|
||||||
|
tools: []ClaudeTool{
|
||||||
|
{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather information",
|
||||||
|
InputSchema: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"location": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLen: 1,
|
||||||
|
description: "标准工具格式应该正常转换",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Custom type tool (MCP format)",
|
||||||
|
tools: []ClaudeTool{
|
||||||
|
{
|
||||||
|
Type: "custom",
|
||||||
|
Name: "mcp_tool",
|
||||||
|
Custom: &CustomToolSpec{
|
||||||
|
Description: "MCP tool description",
|
||||||
|
InputSchema: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"param": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLen: 1,
|
||||||
|
description: "Custom类型工具应该从Custom字段读取description和input_schema",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed standard and custom tools",
|
||||||
|
tools: []ClaudeTool{
|
||||||
|
{
|
||||||
|
Name: "standard_tool",
|
||||||
|
Description: "Standard tool",
|
||||||
|
InputSchema: map[string]any{"type": "object"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "custom",
|
||||||
|
Name: "custom_tool",
|
||||||
|
Custom: &CustomToolSpec{
|
||||||
|
Description: "Custom tool",
|
||||||
|
InputSchema: map[string]any{"type": "object"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations
|
||||||
|
description: "混合标准和custom工具应该都能正确转换",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid custom tool - nil Custom field",
|
||||||
|
tools: []ClaudeTool{
|
||||||
|
{
|
||||||
|
Type: "custom",
|
||||||
|
Name: "invalid_custom",
|
||||||
|
// Custom 为 nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLen: 0, // 应该被跳过
|
||||||
|
description: "Custom字段为nil的custom工具应该被跳过",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid custom tool - nil InputSchema",
|
||||||
|
tools: []ClaudeTool{
|
||||||
|
{
|
||||||
|
Type: "custom",
|
||||||
|
Name: "invalid_custom",
|
||||||
|
Custom: &CustomToolSpec{
|
||||||
|
Description: "Invalid",
|
||||||
|
// InputSchema 为 nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLen: 0, // 应该被跳过
|
||||||
|
description: "InputSchema为nil的custom工具应该被跳过",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := buildTools(tt.tools)
|
||||||
|
|
||||||
|
if len(result) != tt.expectedLen {
|
||||||
|
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证function declarations存在
|
||||||
|
if len(result) > 0 && result[0].FunctionDeclarations != nil {
|
||||||
|
if len(result[0].FunctionDeclarations) != len(tt.tools) {
|
||||||
|
t.Errorf("%s: got %d function declarations, want %d",
|
||||||
|
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
|
|||||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
|
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||||
|
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||||
|
|
||||||
|
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||||
|
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||||
|
|
||||||
// Claude Code 客户端默认请求头
|
// Claude Code 客户端默认请求头
|
||||||
var DefaultHeaders = map[string]string{
|
var DefaultHeaders = map[string]string{
|
||||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||||
|
|||||||
157
backend/internal/pkg/geminicli/drive_client.go
Normal file
157
backend/internal/pkg/geminicli/drive_client.go
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
package geminicli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DriveStorageInfo represents Google Drive storage quota information
|
||||||
|
type DriveStorageInfo struct {
|
||||||
|
Limit int64 `json:"limit"` // Storage limit in bytes
|
||||||
|
Usage int64 `json:"usage"` // Current usage in bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// DriveClient interface for Google Drive API operations
|
||||||
|
type DriveClient interface {
|
||||||
|
GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type driveClient struct{}
|
||||||
|
|
||||||
|
// NewDriveClient creates a new Drive API client
|
||||||
|
func NewDriveClient() DriveClient {
|
||||||
|
return &driveClient{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStorageQuota fetches storage quota from Google Drive API
|
||||||
|
func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error) {
|
||||||
|
const driveAPIURL = "https://www.googleapis.com/drive/v3/about?fields=storageQuota"
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", driveAPIURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
|
||||||
|
// Get HTTP client with proxy support
|
||||||
|
client, err := httpclient.GetClient(httpclient.Options{
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sleepWithContext := func(d time.Duration) error {
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
|
||||||
|
var resp *http.Response
|
||||||
|
maxRetries := 3
|
||||||
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
// Network error retry
|
||||||
|
if attempt < maxRetries-1 {
|
||||||
|
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||||
|
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||||
|
if err := sleepWithContext(backoff + jitter); err != nil {
|
||||||
|
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry 429, 500, 502, 503 with exponential backoff + jitter
|
||||||
|
if (resp.StatusCode == http.StatusTooManyRequests ||
|
||||||
|
resp.StatusCode == http.StatusInternalServerError ||
|
||||||
|
resp.StatusCode == http.StatusBadGateway ||
|
||||||
|
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
|
||||||
|
if err := func() error {
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||||
|
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||||
|
return sleepWithContext(backoff + jitter)
|
||||||
|
}(); err != nil {
|
||||||
|
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp == nil {
|
||||||
|
return nil, fmt.Errorf("request failed: no response received")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
statusText := http.StatusText(resp.StatusCode)
|
||||||
|
if statusText == "" {
|
||||||
|
statusText = resp.Status
|
||||||
|
}
|
||||||
|
fmt.Printf("[DriveClient] Drive API error: status=%d, msg=%s\n", resp.StatusCode, statusText)
|
||||||
|
// 只返回通用错误
|
||||||
|
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
// Parse response
|
||||||
|
var result struct {
|
||||||
|
StorageQuota struct {
|
||||||
|
Limit string `json:"limit"` // Can be string or number
|
||||||
|
Usage string `json:"usage"`
|
||||||
|
} `json:"storageQuota"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse limit and usage (handle both string and number formats)
|
||||||
|
var limit, usage int64
|
||||||
|
if result.StorageQuota.Limit != "" {
|
||||||
|
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
|
||||||
|
limit = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if result.StorageQuota.Usage != "" {
|
||||||
|
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
|
||||||
|
usage = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DriveStorageInfo{
|
||||||
|
Limit: limit,
|
||||||
|
Usage: usage,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
18
backend/internal/pkg/geminicli/drive_client_test.go
Normal file
18
backend/internal/pkg/geminicli/drive_client_test.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package geminicli
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDriveStorageInfo(t *testing.T) {
|
||||||
|
// 测试 DriveStorageInfo 结构体
|
||||||
|
info := &DriveStorageInfo{
|
||||||
|
Limit: 100 * 1024 * 1024 * 1024, // 100GB
|
||||||
|
Usage: 50 * 1024 * 1024 * 1024, // 50GB
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Limit != 100*1024*1024*1024 {
|
||||||
|
t.Errorf("Expected limit 100GB, got %d", info.Limit)
|
||||||
|
}
|
||||||
|
if info.Usage != 50*1024*1024*1024 {
|
||||||
|
t.Errorf("Expected usage 50GB, got %d", info.Usage)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -124,6 +124,90 @@ func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Acc
|
|||||||
return &accounts[0], nil
|
return &accounts[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return []*service.Account{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// De-duplicate while preserving order of first occurrence.
|
||||||
|
uniqueIDs := make([]int64, 0, len(ids))
|
||||||
|
seen := make(map[int64]struct{}, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
if id <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
uniqueIDs = append(uniqueIDs, id)
|
||||||
|
}
|
||||||
|
if len(uniqueIDs) == 0 {
|
||||||
|
return []*service.Account{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
entAccounts, err := r.client.Account.
|
||||||
|
Query().
|
||||||
|
Where(dbaccount.IDIn(uniqueIDs...)).
|
||||||
|
WithProxy().
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(entAccounts) == 0 {
|
||||||
|
return []*service.Account{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accountIDs := make([]int64, 0, len(entAccounts))
|
||||||
|
entByID := make(map[int64]*dbent.Account, len(entAccounts))
|
||||||
|
for _, acc := range entAccounts {
|
||||||
|
entByID[acc.ID] = acc
|
||||||
|
accountIDs = append(accountIDs, acc.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
outByID := make(map[int64]*service.Account, len(entAccounts))
|
||||||
|
for _, entAcc := range entAccounts {
|
||||||
|
out := accountEntityToService(entAcc)
|
||||||
|
if out == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefer the preloaded proxy edge when available.
|
||||||
|
if entAcc.Edges.Proxy != nil {
|
||||||
|
out.Proxy = proxyEntityToService(entAcc.Edges.Proxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
if groups, ok := groupsByAccount[entAcc.ID]; ok {
|
||||||
|
out.Groups = groups
|
||||||
|
}
|
||||||
|
if groupIDs, ok := groupIDsByAccount[entAcc.ID]; ok {
|
||||||
|
out.GroupIDs = groupIDs
|
||||||
|
}
|
||||||
|
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
|
||||||
|
out.AccountGroups = ags
|
||||||
|
}
|
||||||
|
outByID[entAcc.ID] = out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve input order (first occurrence), and ignore missing IDs.
|
||||||
|
out := make([]*service.Account, 0, len(uniqueIDs))
|
||||||
|
for _, id := range uniqueIDs {
|
||||||
|
if _, ok := entByID[id]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if acc, ok := outByID[id]; ok && acc != nil {
|
||||||
|
out = append(out, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ExistsByID 检查指定 ID 的账号是否存在。
|
// ExistsByID 检查指定 ID 的账号是否存在。
|
||||||
// 相比 GetByID,此方法性能更优,因为:
|
// 相比 GetByID,此方法性能更优,因为:
|
||||||
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值
|
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值
|
||||||
|
|||||||
@@ -294,7 +294,6 @@ func userEntityToService(u *dbent.User) *service.User {
|
|||||||
ID: u.ID,
|
ID: u.ID,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
Username: u.Username,
|
Username: u.Username,
|
||||||
Wechat: u.Wechat,
|
|
||||||
Notes: u.Notes,
|
Notes: u.Notes,
|
||||||
PasswordHash: u.PasswordHash,
|
PasswordHash: u.PasswordHash,
|
||||||
Role: u.Role,
|
Role: u.Role,
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
@@ -27,6 +29,8 @@ const (
|
|||||||
userSlotKeyPrefix = "concurrency:user:"
|
userSlotKeyPrefix = "concurrency:user:"
|
||||||
// 等待队列计数器格式: concurrency:wait:{userID}
|
// 等待队列计数器格式: concurrency:wait:{userID}
|
||||||
waitQueueKeyPrefix = "concurrency:wait:"
|
waitQueueKeyPrefix = "concurrency:wait:"
|
||||||
|
// 账号级等待队列计数器格式: wait:account:{accountID}
|
||||||
|
accountWaitKeyPrefix = "wait:account:"
|
||||||
|
|
||||||
// 默认槽位过期时间(分钟),可通过配置覆盖
|
// 默认槽位过期时间(分钟),可通过配置覆盖
|
||||||
defaultSlotTTLMinutes = 15
|
defaultSlotTTLMinutes = 15
|
||||||
@@ -112,33 +116,112 @@ var (
|
|||||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
end
|
end
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
|
|
||||||
|
// incrementAccountWaitScript - account-level wait queue count
|
||||||
|
incrementAccountWaitScript = redis.NewScript(`
|
||||||
|
local current = redis.call('GET', KEYS[1])
|
||||||
|
if current == false then
|
||||||
|
current = 0
|
||||||
|
else
|
||||||
|
current = tonumber(current)
|
||||||
|
end
|
||||||
|
|
||||||
|
if current >= tonumber(ARGV[1]) then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
|
||||||
|
local newVal = redis.call('INCR', KEYS[1])
|
||||||
|
|
||||||
|
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||||
|
if newVal == 1 then
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
|
end
|
||||||
|
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
|
||||||
// decrementWaitScript - same as before
|
// decrementWaitScript - same as before
|
||||||
decrementWaitScript = redis.NewScript(`
|
decrementWaitScript = redis.NewScript(`
|
||||||
local current = redis.call('GET', KEYS[1])
|
local current = redis.call('GET', KEYS[1])
|
||||||
if current ~= false and tonumber(current) > 0 then
|
if current ~= false and tonumber(current) > 0 then
|
||||||
redis.call('DECR', KEYS[1])
|
redis.call('DECR', KEYS[1])
|
||||||
end
|
end
|
||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
|
|
||||||
|
// getAccountsLoadBatchScript - batch load query (read-only)
|
||||||
|
// ARGV[1] = slot TTL (seconds, retained for compatibility)
|
||||||
|
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
|
||||||
|
getAccountsLoadBatchScript = redis.NewScript(`
|
||||||
|
local result = {}
|
||||||
|
|
||||||
|
local i = 2
|
||||||
|
while i <= #ARGV do
|
||||||
|
local accountID = ARGV[i]
|
||||||
|
local maxConcurrency = tonumber(ARGV[i + 1])
|
||||||
|
|
||||||
|
local slotKey = 'concurrency:account:' .. accountID
|
||||||
|
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||||||
|
|
||||||
|
local waitKey = 'wait:account:' .. accountID
|
||||||
|
local waitingCount = redis.call('GET', waitKey)
|
||||||
|
if waitingCount == false then
|
||||||
|
waitingCount = 0
|
||||||
|
else
|
||||||
|
waitingCount = tonumber(waitingCount)
|
||||||
|
end
|
||||||
|
|
||||||
|
local loadRate = 0
|
||||||
|
if maxConcurrency > 0 then
|
||||||
|
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
|
||||||
|
end
|
||||||
|
|
||||||
|
table.insert(result, accountID)
|
||||||
|
table.insert(result, currentConcurrency)
|
||||||
|
table.insert(result, waitingCount)
|
||||||
|
table.insert(result, loadRate)
|
||||||
|
|
||||||
|
i = i + 2
|
||||||
|
end
|
||||||
|
|
||||||
|
return result
|
||||||
|
`)
|
||||||
|
|
||||||
|
// cleanupExpiredSlotsScript - remove expired slots
|
||||||
|
// KEYS[1] = concurrency:account:{accountID}
|
||||||
|
// ARGV[1] = TTL (seconds)
|
||||||
|
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||||
|
local key = KEYS[1]
|
||||||
|
local ttl = tonumber(ARGV[1])
|
||||||
|
local timeResult = redis.call('TIME')
|
||||||
|
local now = tonumber(timeResult[1])
|
||||||
|
local expireBefore = now - ttl
|
||||||
|
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||||
|
`)
|
||||||
)
|
)
|
||||||
|
|
||||||
type concurrencyCache struct {
|
type concurrencyCache struct {
|
||||||
rdb *redis.Client
|
rdb *redis.Client
|
||||||
slotTTLSeconds int // 槽位过期时间(秒)
|
slotTTLSeconds int // 槽位过期时间(秒)
|
||||||
|
waitQueueTTLSeconds int // 等待队列过期时间(秒)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConcurrencyCache 创建并发控制缓存
|
// NewConcurrencyCache 创建并发控制缓存
|
||||||
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
|
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
|
||||||
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
|
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL
|
||||||
|
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
|
||||||
if slotTTLMinutes <= 0 {
|
if slotTTLMinutes <= 0 {
|
||||||
slotTTLMinutes = defaultSlotTTLMinutes
|
slotTTLMinutes = defaultSlotTTLMinutes
|
||||||
}
|
}
|
||||||
|
if waitQueueTTLSeconds <= 0 {
|
||||||
|
waitQueueTTLSeconds = slotTTLMinutes * 60
|
||||||
|
}
|
||||||
return &concurrencyCache{
|
return &concurrencyCache{
|
||||||
rdb: rdb,
|
rdb: rdb,
|
||||||
slotTTLSeconds: slotTTLMinutes * 60,
|
slotTTLSeconds: slotTTLMinutes * 60,
|
||||||
|
waitQueueTTLSeconds: waitQueueTTLSeconds,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,6 +238,10 @@ func waitQueueKey(userID int64) string {
|
|||||||
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func accountWaitKey(accountID int64) string {
|
||||||
|
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// Account slot operations
|
// Account slot operations
|
||||||
|
|
||||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
@@ -225,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
|
|||||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Account wait queue operations
|
||||||
|
|
||||||
|
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
|
key := accountWaitKey(accountID)
|
||||||
|
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return result == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||||
|
key := accountWaitKey(accountID)
|
||||||
|
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
key := accountWaitKey(accountID)
|
||||||
|
val, err := c.rdb.Get(ctx, key).Int()
|
||||||
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if errors.Is(err, redis.Nil) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||||
|
if len(accounts) == 0 {
|
||||||
|
return map[int64]*service.AccountLoadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []any{c.slotTTLSeconds}
|
||||||
|
for _, acc := range accounts {
|
||||||
|
args = append(args, acc.ID, acc.MaxConcurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
loadMap := make(map[int64]*service.AccountLoadInfo)
|
||||||
|
for i := 0; i < len(result); i += 4 {
|
||||||
|
if i+3 >= len(result) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
||||||
|
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
||||||
|
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
||||||
|
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
||||||
|
|
||||||
|
loadMap[accountID] = &service.AccountLoadInfo{
|
||||||
|
AccountID: accountID,
|
||||||
|
CurrentConcurrency: currentConcurrency,
|
||||||
|
WaitingCount: waitingCount,
|
||||||
|
LoadRate: loadRate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return loadMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||||
|
key := accountSlotKey(accountID)
|
||||||
|
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) {
|
|||||||
_ = rdb.Close()
|
_ = rdb.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache)
|
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
for _, size := range []int{10, 100, 1000} {
|
for _, size := range []int{10, 100, 1000} {
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct {
|
|||||||
|
|
||||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||||
s.IntegrationRedisSuite.SetupTest()
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes)
|
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||||
@@ -218,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
|||||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAccountWaitQueue_IncrementAndDecrement verifies that the per-account
// wait counter accepts increments up to maxWait, rejects the increment that
// would exceed it, carries a TTL near testSlotTTL, and steps back down on
// decrement.
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
	accountID := int64(30)
	waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)

	// First two increments fit within maxWait=2.
	ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
	require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
	require.True(s.T(), ok)

	ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
	require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
	require.True(s.T(), ok)

	// Third increment exceeds the cap and must be refused (ok == false).
	ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
	require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
	require.False(s.T(), ok, "expected account wait increment over max to fail")

	// The wait key must expire on its own if never decremented.
	ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
	require.NoError(s.T(), err, "TTL account waitKey")
	s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)

	require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")

	// After one decrement the counter should read 1 (redis.Nil would mean 0).
	val, err := s.rdb.Get(s.ctx, waitKey).Int()
	if !errors.Is(err, redis.Nil) {
		require.NoError(s.T(), err, "Get waitKey")
	}
	require.Equal(s.T(), 1, val, "expected account wait count 1")
}
|
||||||
|
|
||||||
|
// TestAccountWaitQueue_DecrementNoNegative confirms that decrementing a
// never-incremented per-account wait counter does not drive it below zero.
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
	accountID := int64(301)
	waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)

	require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")

	// redis.Nil means the key was never created, which reads as 0 here.
	val, err := s.rdb.Get(s.ctx, waitKey).Int()
	if !errors.Is(err, redis.Nil) {
		require.NoError(s.T(), err, "Get waitKey")
	}
	require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
|
||||||
|
|
||||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||||
// When no slots exist, GetAccountConcurrency should return 0
|
// When no slots exist, GetAccountConcurrency should return 0
|
||||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
||||||
@@ -232,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
|
|||||||
require.Equal(s.T(), 0, cur)
|
require.Equal(s.T(), 0, cur)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGetAccountsLoadBatch exercises the batched load query across three
// accounts in different states (busy with a waiter, half loaded, idle) and
// checks the reported concurrency, waiting count, and load rate for each.
// NOTE(review): currently skipped — CurrentConcurrency reads 0 in CI; the
// root cause has not been confirmed from this file.
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
	s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
	// Setup: Create accounts with different load states
	account1 := int64(100)
	account2 := int64(101)
	account3 := int64(102)

	// Account 1: 2/3 slots used, 1 waiting
	ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
	require.NoError(s.T(), err)
	require.True(s.T(), ok)
	ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
	require.NoError(s.T(), err)
	require.True(s.T(), ok)
	ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
	require.NoError(s.T(), err)
	require.True(s.T(), ok)

	// Account 2: 1/2 slots used, 0 waiting
	ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
	require.NoError(s.T(), err)
	require.True(s.T(), ok)

	// Account 3: 0/1 slots used, 0 waiting (idle)

	// Query batch load
	accounts := []service.AccountWithConcurrency{
		{ID: account1, MaxConcurrency: 3},
		{ID: account2, MaxConcurrency: 2},
		{ID: account3, MaxConcurrency: 1},
	}

	loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
	require.NoError(s.T(), err)
	require.Len(s.T(), loadMap, 3)

	// Verify account1: (2 + 1) / 3 = 100%
	load1 := loadMap[account1]
	require.NotNil(s.T(), load1)
	require.Equal(s.T(), account1, load1.AccountID)
	require.Equal(s.T(), 2, load1.CurrentConcurrency)
	require.Equal(s.T(), 1, load1.WaitingCount)
	require.Equal(s.T(), 100, load1.LoadRate)

	// Verify account2: (1 + 0) / 2 = 50%
	load2 := loadMap[account2]
	require.NotNil(s.T(), load2)
	require.Equal(s.T(), account2, load2.AccountID)
	require.Equal(s.T(), 1, load2.CurrentConcurrency)
	require.Equal(s.T(), 0, load2.WaitingCount)
	require.Equal(s.T(), 50, load2.LoadRate)

	// Verify account3: (0 + 0) / 1 = 0%
	load3 := loadMap[account3]
	require.NotNil(s.T(), load3)
	require.Equal(s.T(), account3, load3.AccountID)
	require.Equal(s.T(), 0, load3.CurrentConcurrency)
	require.Equal(s.T(), 0, load3.WaitingCount)
	require.Equal(s.T(), 0, load3.LoadRate)
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
|
||||||
|
// Test with empty account list
|
||||||
|
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Empty(s.T(), loadMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCleanupExpiredAccountSlots verifies that cleanup removes only slot
// entries whose timestamps have aged past testSlotTTL, leaving fresh holders
// in place.
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
	accountID := int64(200)
	slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)

	// Acquire 3 slots
	ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
	require.NoError(s.T(), err)
	require.True(s.T(), ok)
	ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
	require.NoError(s.T(), err)
	require.True(s.T(), ok)
	ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
	require.NoError(s.T(), err)
	require.True(s.T(), ok)

	// Verify 3 slots exist
	cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
	require.NoError(s.T(), err)
	require.Equal(s.T(), 3, cur)

	// Manually set old timestamps for req1 and req2 (simulate expired slots)
	// by rewriting their sorted-set scores to a time beyond the TTL window.
	now := time.Now().Unix()
	expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
	err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
	require.NoError(s.T(), err)
	err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
	require.NoError(s.T(), err)

	// Run cleanup
	err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
	require.NoError(s.T(), err)

	// Verify only 1 slot remains (req3)
	cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
	require.NoError(s.T(), err)
	require.Equal(s.T(), 1, cur)

	// Verify req3 still exists
	members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
	require.NoError(s.T(), err)
	require.Len(s.T(), members, 1)
	require.Equal(s.T(), "req3", members[0])
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||||
|
accountID := int64(201)
|
||||||
|
|
||||||
|
// Acquire 2 fresh slots
|
||||||
|
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
// Run cleanup (should not remove anything)
|
||||||
|
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
|
||||||
|
// Verify both slots still exist
|
||||||
|
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), 2, cur)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *servic
|
|||||||
SetBalance(u.Balance).
|
SetBalance(u.Balance).
|
||||||
SetConcurrency(u.Concurrency).
|
SetConcurrency(u.Concurrency).
|
||||||
SetUsername(u.Username).
|
SetUsername(u.Username).
|
||||||
SetWechat(u.Wechat).
|
|
||||||
SetNotes(u.Notes)
|
SetNotes(u.Notes)
|
||||||
if !u.CreatedAt.IsZero() {
|
if !u.CreatedAt.IsZero() {
|
||||||
create.SetCreatedAt(u.CreatedAt)
|
create.SetCreatedAt(u.CreatedAt)
|
||||||
|
|||||||
@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
|||||||
if existing != checksum {
|
if existing != checksum {
|
||||||
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
|
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
|
||||||
// 正确的做法是创建新的迁移文件来进行变更。
|
// 正确的做法是创建新的迁移文件来进行变更。
|
||||||
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
|
return fmt.Errorf(
|
||||||
|
"migration %s checksum mismatch (db=%s file=%s)\n"+
|
||||||
|
"This means the migration file was modified after being applied to the database.\n"+
|
||||||
|
"Solutions:\n"+
|
||||||
|
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
|
||||||
|
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
|
||||||
|
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
|
||||||
|
name, existing, checksum, name, name,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
continue // 迁移已应用且校验和匹配,跳过
|
continue // 迁移已应用且校验和匹配,跳过
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
|||||||
|
|
||||||
// users: columns required by repository queries
|
// users: columns required by repository queries
|
||||||
requireColumn(t, tx, "users", "username", "character varying", 100, false)
|
requireColumn(t, tx, "users", "username", "character varying", 100, false)
|
||||||
requireColumn(t, tx, "users", "wechat", "character varying", 100, false)
|
|
||||||
requireColumn(t, tx, "users", "notes", "text", 0, false)
|
requireColumn(t, tx, "users", "notes", "text", 0, false)
|
||||||
|
|
||||||
// accounts: schedulable and rate-limit fields
|
// accounts: schedulable and rate-limit fields
|
||||||
|
|||||||
385
backend/internal/repository/user_attribute_repo.go
Normal file
385
backend/internal/repository/user_attribute_repo.go
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// userAttributeDefinitionRepository is the ent-backed implementation of
// service.UserAttributeDefinitionRepository.
type userAttributeDefinitionRepository struct {
	client *dbent.Client // base client; per-request transactional clients come via clientFromContext
}

// NewUserAttributeDefinitionRepository creates a new repository instance
// backed by the given ent client.
func NewUserAttributeDefinitionRepository(client *dbent.Client) service.UserAttributeDefinitionRepository {
	return &userAttributeDefinitionRepository{client: client}
}
|
||||||
|
|
||||||
|
func (r *userAttributeDefinitionRepository) Create(ctx context.Context, def *service.UserAttributeDefinition) error {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
created, err := client.UserAttributeDefinition.Create().
|
||||||
|
SetKey(def.Key).
|
||||||
|
SetName(def.Name).
|
||||||
|
SetDescription(def.Description).
|
||||||
|
SetType(string(def.Type)).
|
||||||
|
SetOptions(toEntOptions(def.Options)).
|
||||||
|
SetRequired(def.Required).
|
||||||
|
SetValidation(toEntValidation(def.Validation)).
|
||||||
|
SetPlaceholder(def.Placeholder).
|
||||||
|
SetEnabled(def.Enabled).
|
||||||
|
Save(ctx)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return translatePersistenceError(err, nil, service.ErrAttributeKeyExists)
|
||||||
|
}
|
||||||
|
|
||||||
|
def.ID = created.ID
|
||||||
|
def.DisplayOrder = created.DisplayOrder
|
||||||
|
def.CreatedAt = created.CreatedAt
|
||||||
|
def.UpdatedAt = created.UpdatedAt
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeDefinitionRepository) GetByID(ctx context.Context, id int64) (*service.UserAttributeDefinition, error) {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
e, err := client.UserAttributeDefinition.Query().
|
||||||
|
Where(userattributedefinition.IDEQ(id)).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
|
||||||
|
}
|
||||||
|
return defEntityToService(e), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeDefinitionRepository) GetByKey(ctx context.Context, key string) (*service.UserAttributeDefinition, error) {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
e, err := client.UserAttributeDefinition.Query().
|
||||||
|
Where(userattributedefinition.KeyEQ(key)).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
|
||||||
|
}
|
||||||
|
return defEntityToService(e), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeDefinitionRepository) Update(ctx context.Context, def *service.UserAttributeDefinition) error {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
updated, err := client.UserAttributeDefinition.UpdateOneID(def.ID).
|
||||||
|
SetName(def.Name).
|
||||||
|
SetDescription(def.Description).
|
||||||
|
SetType(string(def.Type)).
|
||||||
|
SetOptions(toEntOptions(def.Options)).
|
||||||
|
SetRequired(def.Required).
|
||||||
|
SetValidation(toEntValidation(def.Validation)).
|
||||||
|
SetPlaceholder(def.Placeholder).
|
||||||
|
SetDisplayOrder(def.DisplayOrder).
|
||||||
|
SetEnabled(def.Enabled).
|
||||||
|
Save(ctx)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, service.ErrAttributeKeyExists)
|
||||||
|
}
|
||||||
|
|
||||||
|
def.UpdatedAt = updated.UpdatedAt
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeDefinitionRepository) Delete(ctx context.Context, id int64) error {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
_, err := client.UserAttributeDefinition.Delete().
|
||||||
|
Where(userattributedefinition.IDEQ(id)).
|
||||||
|
Exec(ctx)
|
||||||
|
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeDefinitionRepository) List(ctx context.Context, enabledOnly bool) ([]service.UserAttributeDefinition, error) {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
q := client.UserAttributeDefinition.Query()
|
||||||
|
if enabledOnly {
|
||||||
|
q = q.Where(userattributedefinition.EnabledEQ(true))
|
||||||
|
}
|
||||||
|
|
||||||
|
entities, err := q.Order(dbent.Asc(userattributedefinition.FieldDisplayOrder)).All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]service.UserAttributeDefinition, 0, len(entities))
|
||||||
|
for _, e := range entities {
|
||||||
|
result = append(result, *defEntityToService(e))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeDefinitionRepository) UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error {
|
||||||
|
tx, err := r.client.Tx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
for id, order := range orders {
|
||||||
|
if _, err := tx.UserAttributeDefinition.UpdateOneID(id).
|
||||||
|
SetDisplayOrder(order).
|
||||||
|
Save(ctx); err != nil {
|
||||||
|
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeDefinitionRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
return client.UserAttributeDefinition.Query().
|
||||||
|
Where(userattributedefinition.KeyEQ(key)).
|
||||||
|
Exist(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// userAttributeValueRepository is the ent-backed implementation of
// service.UserAttributeValueRepository.
type userAttributeValueRepository struct {
	client *dbent.Client // base client; per-request transactional clients come via clientFromContext
}

// NewUserAttributeValueRepository creates a new repository instance backed by
// the given ent client.
func NewUserAttributeValueRepository(client *dbent.Client) service.UserAttributeValueRepository {
	return &userAttributeValueRepository{client: client}
}
|
||||||
|
|
||||||
|
func (r *userAttributeValueRepository) GetByUserID(ctx context.Context, userID int64) ([]service.UserAttributeValue, error) {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
entities, err := client.UserAttributeValue.Query().
|
||||||
|
Where(userattributevalue.UserIDEQ(userID)).
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]service.UserAttributeValue, 0, len(entities))
|
||||||
|
for _, e := range entities {
|
||||||
|
result = append(result, service.UserAttributeValue{
|
||||||
|
ID: e.ID,
|
||||||
|
UserID: e.UserID,
|
||||||
|
AttributeID: e.AttributeID,
|
||||||
|
Value: e.Value,
|
||||||
|
CreatedAt: e.CreatedAt,
|
||||||
|
UpdatedAt: e.UpdatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeValueRepository) GetByUserIDs(ctx context.Context, userIDs []int64) ([]service.UserAttributeValue, error) {
|
||||||
|
if len(userIDs) == 0 {
|
||||||
|
return []service.UserAttributeValue{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
entities, err := client.UserAttributeValue.Query().
|
||||||
|
Where(userattributevalue.UserIDIn(userIDs...)).
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]service.UserAttributeValue, 0, len(entities))
|
||||||
|
for _, e := range entities {
|
||||||
|
result = append(result, service.UserAttributeValue{
|
||||||
|
ID: e.ID,
|
||||||
|
UserID: e.UserID,
|
||||||
|
AttributeID: e.AttributeID,
|
||||||
|
Value: e.Value,
|
||||||
|
CreatedAt: e.CreatedAt,
|
||||||
|
UpdatedAt: e.UpdatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeValueRepository) UpsertBatch(ctx context.Context, userID int64, inputs []service.UpdateUserAttributeInput) error {
|
||||||
|
if len(inputs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := r.client.Tx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
for _, input := range inputs {
|
||||||
|
// Use upsert (ON CONFLICT DO UPDATE)
|
||||||
|
err := tx.UserAttributeValue.Create().
|
||||||
|
SetUserID(userID).
|
||||||
|
SetAttributeID(input.AttributeID).
|
||||||
|
SetValue(input.Value).
|
||||||
|
OnConflictColumns(userattributevalue.FieldUserID, userattributevalue.FieldAttributeID).
|
||||||
|
UpdateValue().
|
||||||
|
UpdateUpdatedAt().
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeValueRepository) DeleteByAttributeID(ctx context.Context, attributeID int64) error {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
_, err := client.UserAttributeValue.Delete().
|
||||||
|
Where(userattributevalue.AttributeIDEQ(attributeID)).
|
||||||
|
Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userAttributeValueRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|
||||||
|
_, err := client.UserAttributeValue.Delete().
|
||||||
|
Where(userattributevalue.UserIDEQ(userID)).
|
||||||
|
Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for entity to service conversion
|
||||||
|
func defEntityToService(e *dbent.UserAttributeDefinition) *service.UserAttributeDefinition {
|
||||||
|
if e == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &service.UserAttributeDefinition{
|
||||||
|
ID: e.ID,
|
||||||
|
Key: e.Key,
|
||||||
|
Name: e.Name,
|
||||||
|
Description: e.Description,
|
||||||
|
Type: service.UserAttributeType(e.Type),
|
||||||
|
Options: toServiceOptions(e.Options),
|
||||||
|
Required: e.Required,
|
||||||
|
Validation: toServiceValidation(e.Validation),
|
||||||
|
Placeholder: e.Placeholder,
|
||||||
|
DisplayOrder: e.DisplayOrder,
|
||||||
|
Enabled: e.Enabled,
|
||||||
|
CreatedAt: e.CreatedAt,
|
||||||
|
UpdatedAt: e.UpdatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type conversion helpers (map types <-> service types)
|
||||||
|
func toEntOptions(opts []service.UserAttributeOption) []map[string]any {
|
||||||
|
if opts == nil {
|
||||||
|
return []map[string]any{}
|
||||||
|
}
|
||||||
|
result := make([]map[string]any, len(opts))
|
||||||
|
for i, o := range opts {
|
||||||
|
result[i] = map[string]any{"value": o.Value, "label": o.Label}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func toServiceOptions(opts []map[string]any) []service.UserAttributeOption {
|
||||||
|
if opts == nil {
|
||||||
|
return []service.UserAttributeOption{}
|
||||||
|
}
|
||||||
|
result := make([]service.UserAttributeOption, len(opts))
|
||||||
|
for i, o := range opts {
|
||||||
|
result[i] = service.UserAttributeOption{
|
||||||
|
Value: getString(o, "value"),
|
||||||
|
Label: getString(o, "label"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func toEntValidation(v service.UserAttributeValidation) map[string]any {
|
||||||
|
result := map[string]any{}
|
||||||
|
if v.MinLength != nil {
|
||||||
|
result["min_length"] = *v.MinLength
|
||||||
|
}
|
||||||
|
if v.MaxLength != nil {
|
||||||
|
result["max_length"] = *v.MaxLength
|
||||||
|
}
|
||||||
|
if v.Min != nil {
|
||||||
|
result["min"] = *v.Min
|
||||||
|
}
|
||||||
|
if v.Max != nil {
|
||||||
|
result["max"] = *v.Max
|
||||||
|
}
|
||||||
|
if v.Pattern != nil {
|
||||||
|
result["pattern"] = *v.Pattern
|
||||||
|
}
|
||||||
|
if v.Message != nil {
|
||||||
|
result["message"] = *v.Message
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func toServiceValidation(v map[string]any) service.UserAttributeValidation {
|
||||||
|
result := service.UserAttributeValidation{}
|
||||||
|
if val := getInt(v, "min_length"); val != nil {
|
||||||
|
result.MinLength = val
|
||||||
|
}
|
||||||
|
if val := getInt(v, "max_length"); val != nil {
|
||||||
|
result.MaxLength = val
|
||||||
|
}
|
||||||
|
if val := getInt(v, "min"); val != nil {
|
||||||
|
result.Min = val
|
||||||
|
}
|
||||||
|
if val := getInt(v, "max"); val != nil {
|
||||||
|
result.Max = val
|
||||||
|
}
|
||||||
|
if val := getStringPtr(v, "pattern"); val != nil {
|
||||||
|
result.Pattern = val
|
||||||
|
}
|
||||||
|
if val := getStringPtr(v, "message"); val != nil {
|
||||||
|
result.Message = val
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for type conversion
|
||||||
|
// getString returns m[key] when it holds a string, or "" otherwise
// (including when the key is absent).
func getString(m map[string]any, key string) string {
	s, _ := m[key].(string)
	return s
}
|
||||||
|
|
||||||
|
// getStringPtr returns a pointer to m[key] when it holds a string, or nil
// when the key is absent or not a string.
func getStringPtr(m map[string]any, key string) *string {
	if s, ok := m[key].(string); ok {
		return &s
	}
	return nil
}
|
||||||
|
|
||||||
|
// getInt coerces m[key] to *int for int, int64, or float64 values (float64
// covers JSON-decoded numbers, truncating toward zero); any other type or a
// missing key yields nil.
func getInt(m map[string]any, key string) *int {
	var out int
	switch n := m[key].(type) {
	case int:
		out = n
	case int64:
		out = int(n)
	case float64:
		out = int(n)
	default:
		return nil
	}
	return &out
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -50,7 +51,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
created, err := txClient.User.Create().
|
created, err := txClient.User.Create().
|
||||||
SetEmail(userIn.Email).
|
SetEmail(userIn.Email).
|
||||||
SetUsername(userIn.Username).
|
SetUsername(userIn.Username).
|
||||||
SetWechat(userIn.Wechat).
|
|
||||||
SetNotes(userIn.Notes).
|
SetNotes(userIn.Notes).
|
||||||
SetPasswordHash(userIn.PasswordHash).
|
SetPasswordHash(userIn.PasswordHash).
|
||||||
SetRole(userIn.Role).
|
SetRole(userIn.Role).
|
||||||
@@ -133,7 +133,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
||||||
SetEmail(userIn.Email).
|
SetEmail(userIn.Email).
|
||||||
SetUsername(userIn.Username).
|
SetUsername(userIn.Username).
|
||||||
SetWechat(userIn.Wechat).
|
|
||||||
SetNotes(userIn.Notes).
|
SetNotes(userIn.Notes).
|
||||||
SetPasswordHash(userIn.PasswordHash).
|
SetPasswordHash(userIn.PasswordHash).
|
||||||
SetRole(userIn.Role).
|
SetRole(userIn.Role).
|
||||||
@@ -171,28 +170,38 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||||
return r.ListWithFilters(ctx, params, "", "", "")
|
return r.ListWithFilters(ctx, params, service.UserListFilters{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
|
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||||
q := r.client.User.Query()
|
q := r.client.User.Query()
|
||||||
|
|
||||||
if status != "" {
|
if filters.Status != "" {
|
||||||
q = q.Where(dbuser.StatusEQ(status))
|
q = q.Where(dbuser.StatusEQ(filters.Status))
|
||||||
}
|
}
|
||||||
if role != "" {
|
if filters.Role != "" {
|
||||||
q = q.Where(dbuser.RoleEQ(role))
|
q = q.Where(dbuser.RoleEQ(filters.Role))
|
||||||
}
|
}
|
||||||
if search != "" {
|
if filters.Search != "" {
|
||||||
q = q.Where(
|
q = q.Where(
|
||||||
dbuser.Or(
|
dbuser.Or(
|
||||||
dbuser.EmailContainsFold(search),
|
dbuser.EmailContainsFold(filters.Search),
|
||||||
dbuser.UsernameContainsFold(search),
|
dbuser.UsernameContainsFold(filters.Search),
|
||||||
dbuser.WechatContainsFold(search),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If attribute filters are specified, we need to filter by user IDs first
|
||||||
|
var allowedUserIDs []int64
|
||||||
|
if len(filters.Attributes) > 0 {
|
||||||
|
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||||||
|
if len(allowedUserIDs) == 0 {
|
||||||
|
// No users match the attribute filters
|
||||||
|
return []service.User{}, paginationResultFromTotal(0, params), nil
|
||||||
|
}
|
||||||
|
q = q.Where(dbuser.IDIn(allowedUserIDs...))
|
||||||
|
}
|
||||||
|
|
||||||
total, err := q.Clone().Count(ctx)
|
total, err := q.Clone().Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -252,6 +261,59 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
|||||||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
|
||||||
|
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
|
||||||
|
if len(attrs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each attribute filter, get the set of matching user IDs
|
||||||
|
// Then intersect all sets to get users matching ALL filters
|
||||||
|
var resultSet map[int64]struct{}
|
||||||
|
first := true
|
||||||
|
|
||||||
|
for attrID, value := range attrs {
|
||||||
|
// Query user_attribute_values for this attribute
|
||||||
|
values, err := r.client.UserAttributeValue.Query().
|
||||||
|
Where(
|
||||||
|
userattributevalue.AttributeIDEQ(attrID),
|
||||||
|
userattributevalue.ValueContainsFold(value),
|
||||||
|
).
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
currentSet := make(map[int64]struct{}, len(values))
|
||||||
|
for _, v := range values {
|
||||||
|
currentSet[v.UserID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if first {
|
||||||
|
resultSet = currentSet
|
||||||
|
first = false
|
||||||
|
} else {
|
||||||
|
// Intersect with previous results
|
||||||
|
for userID := range resultSet {
|
||||||
|
if _, ok := currentSet[userID]; !ok {
|
||||||
|
delete(resultSet, userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Early exit if no users match
|
||||||
|
if len(resultSet) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]int64, 0, len(resultSet))
|
||||||
|
for userID := range resultSet {
|
||||||
|
result = append(result, userID)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
|
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ func (s *UserRepoSuite) TestListWithFilters_Status() {
|
|||||||
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
|
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
|
||||||
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
|
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
|
||||||
|
|
||||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive})
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(users, 1)
|
s.Require().Len(users, 1)
|
||||||
s.Require().Equal(service.StatusActive, users[0].Status)
|
s.Require().Equal(service.StatusActive, users[0].Status)
|
||||||
@@ -176,7 +176,7 @@ func (s *UserRepoSuite) TestListWithFilters_Role() {
|
|||||||
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
|
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
|
||||||
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
|
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
|
||||||
|
|
||||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Role: service.RoleAdmin})
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(users, 1)
|
s.Require().Len(users, 1)
|
||||||
s.Require().Equal(service.RoleAdmin, users[0].Role)
|
s.Require().Equal(service.RoleAdmin, users[0].Role)
|
||||||
@@ -186,7 +186,7 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
|
|||||||
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
|
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
|
||||||
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
|
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
|
||||||
|
|
||||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "alice"})
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(users, 1)
|
s.Require().Len(users, 1)
|
||||||
s.Require().Contains(users[0].Email, "alice")
|
s.Require().Contains(users[0].Email, "alice")
|
||||||
@@ -196,22 +196,12 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
|
|||||||
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
|
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
|
||||||
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
|
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
|
||||||
|
|
||||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "john"})
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(users, 1)
|
s.Require().Len(users, 1)
|
||||||
s.Require().Equal("JohnDoe", users[0].Username)
|
s.Require().Equal("JohnDoe", users[0].Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
|
|
||||||
s.mustCreateUser(&service.User{Email: "w1@test.com", Wechat: "wx_hello"})
|
|
||||||
s.mustCreateUser(&service.User{Email: "w2@test.com", Wechat: "wx_world"})
|
|
||||||
|
|
||||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
|
|
||||||
s.Require().NoError(err)
|
|
||||||
s.Require().Len(users, 1)
|
|
||||||
s.Require().Equal("wx_hello", users[0].Wechat)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
||||||
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
|
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
|
||||||
groupActive := s.mustCreateGroup("g-sub-active")
|
groupActive := s.mustCreateGroup("g-sub-active")
|
||||||
@@ -226,7 +216,7 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
|||||||
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
|
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
|
||||||
})
|
})
|
||||||
|
|
||||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "sub@"})
|
||||||
s.Require().NoError(err, "ListWithFilters")
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
s.Require().Len(users, 1, "expected 1 user")
|
s.Require().Len(users, 1, "expected 1 user")
|
||||||
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
|
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
|
||||||
@@ -238,7 +228,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
|||||||
s.mustCreateUser(&service.User{
|
s.mustCreateUser(&service.User{
|
||||||
Email: "a@example.com",
|
Email: "a@example.com",
|
||||||
Username: "Alice",
|
Username: "Alice",
|
||||||
Wechat: "wx_a",
|
|
||||||
Role: service.RoleUser,
|
Role: service.RoleUser,
|
||||||
Status: service.StatusActive,
|
Status: service.StatusActive,
|
||||||
Balance: 10,
|
Balance: 10,
|
||||||
@@ -246,7 +235,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
|||||||
target := s.mustCreateUser(&service.User{
|
target := s.mustCreateUser(&service.User{
|
||||||
Email: "b@example.com",
|
Email: "b@example.com",
|
||||||
Username: "Bob",
|
Username: "Bob",
|
||||||
Wechat: "wx_b",
|
|
||||||
Role: service.RoleAdmin,
|
Role: service.RoleAdmin,
|
||||||
Status: service.StatusActive,
|
Status: service.StatusActive,
|
||||||
Balance: 1,
|
Balance: 1,
|
||||||
@@ -257,7 +245,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
|||||||
Status: service.StatusDisabled,
|
Status: service.StatusDisabled,
|
||||||
})
|
})
|
||||||
|
|
||||||
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@")
|
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
|
||||||
s.Require().NoError(err, "ListWithFilters")
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||||
@@ -448,7 +436,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
|||||||
user1 := s.mustCreateUser(&service.User{
|
user1 := s.mustCreateUser(&service.User{
|
||||||
Email: "a@example.com",
|
Email: "a@example.com",
|
||||||
Username: "Alice",
|
Username: "Alice",
|
||||||
Wechat: "wx_a",
|
|
||||||
Role: service.RoleUser,
|
Role: service.RoleUser,
|
||||||
Status: service.StatusActive,
|
Status: service.StatusActive,
|
||||||
Balance: 10,
|
Balance: 10,
|
||||||
@@ -456,7 +443,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
|||||||
user2 := s.mustCreateUser(&service.User{
|
user2 := s.mustCreateUser(&service.User{
|
||||||
Email: "b@example.com",
|
Email: "b@example.com",
|
||||||
Username: "Bob",
|
Username: "Bob",
|
||||||
Wechat: "wx_b",
|
|
||||||
Role: service.RoleAdmin,
|
Role: service.RoleAdmin,
|
||||||
Status: service.StatusActive,
|
Status: service.StatusActive,
|
||||||
Balance: 1,
|
Balance: 1,
|
||||||
@@ -501,7 +487,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
|||||||
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
|
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
|
||||||
|
|
||||||
params := pagination.PaginationParams{Page: 1, PageSize: 10}
|
params := pagination.PaginationParams{Page: 1, PageSize: 10}
|
||||||
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
|
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
|
||||||
s.Require().NoError(err, "ListWithFilters")
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||||
|
|||||||
@@ -15,7 +15,14 @@ import (
|
|||||||
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
|
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
|
||||||
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
|
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
|
||||||
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
|
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
|
||||||
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes)
|
waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds())
|
||||||
|
if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout {
|
||||||
|
waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds())
|
||||||
|
}
|
||||||
|
if waitTTLSeconds <= 0 {
|
||||||
|
waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60
|
||||||
|
}
|
||||||
|
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProviderSet is the Wire provider set for all repositories
|
// ProviderSet is the Wire provider set for all repositories
|
||||||
@@ -29,6 +36,8 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewUsageLogRepository,
|
NewUsageLogRepository,
|
||||||
NewSettingRepository,
|
NewSettingRepository,
|
||||||
NewUserSubscriptionRepository,
|
NewUserSubscriptionRepository,
|
||||||
|
NewUserAttributeDefinitionRepository,
|
||||||
|
NewUserAttributeValueRepository,
|
||||||
|
|
||||||
// Cache implementations
|
// Cache implementations
|
||||||
NewGatewayCache,
|
NewGatewayCache,
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"id": 1,
|
"id": 1,
|
||||||
"email": "alice@example.com",
|
"email": "alice@example.com",
|
||||||
"username": "alice",
|
"username": "alice",
|
||||||
"wechat": "wx_alice",
|
|
||||||
"notes": "hello",
|
"notes": "hello",
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"balance": 12.5,
|
"balance": 12.5,
|
||||||
@@ -348,7 +347,6 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
ID: 1,
|
ID: 1,
|
||||||
Email: "alice@example.com",
|
Email: "alice@example.com",
|
||||||
Username: "alice",
|
Username: "alice",
|
||||||
Wechat: "wx_alice",
|
|
||||||
Notes: "hello",
|
Notes: "hello",
|
||||||
Role: service.RoleUser,
|
Role: service.RoleUser,
|
||||||
Balance: 12.5,
|
Balance: 12.5,
|
||||||
@@ -503,7 +501,7 @@ func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationPar
|
|||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
|
func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,6 +54,9 @@ func RegisterAdminRoutes(
|
|||||||
|
|
||||||
// 使用记录管理
|
// 使用记录管理
|
||||||
registerUsageRoutes(admin, h)
|
registerUsageRoutes(admin, h)
|
||||||
|
|
||||||
|
// 用户属性管理
|
||||||
|
registerUserAttributeRoutes(admin, h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,6 +85,10 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||||
|
|
||||||
|
// User attribute values
|
||||||
|
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
||||||
|
users.PUT("/:id/attributes", h.Admin.UserAttribute.UpdateUserAttributes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,6 +117,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||||
|
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
|
||||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||||
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
||||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||||
@@ -119,6 +127,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||||
|
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
|
||||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||||
|
|
||||||
// Claude OAuth routes
|
// Claude OAuth routes
|
||||||
@@ -242,3 +251,15 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
|
attrs := admin.Group("/user-attributes")
|
||||||
|
{
|
||||||
|
attrs.GET("", h.Admin.UserAttribute.ListDefinitions)
|
||||||
|
attrs.POST("", h.Admin.UserAttribute.CreateDefinition)
|
||||||
|
attrs.POST("/batch", h.Admin.UserAttribute.GetBatchUserAttributes)
|
||||||
|
attrs.PUT("/reorder", h.Admin.UserAttribute.ReorderDefinitions)
|
||||||
|
attrs.PUT("/:id", h.Admin.UserAttribute.UpdateDefinition)
|
||||||
|
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
|
|||||||
return a.Platform == PlatformGemini
|
return a.Platform == PlatformGemini
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) GeminiOAuthType() string {
|
||||||
|
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
|
||||||
|
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
|
||||||
|
return "code_assist"
|
||||||
|
}
|
||||||
|
return oauthType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GeminiTierID() string {
|
||||||
|
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
|
||||||
|
if tierID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.ToUpper(tierID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) IsGeminiCodeAssist() bool {
|
||||||
|
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
oauthType := a.GeminiOAuthType()
|
||||||
|
if oauthType == "" {
|
||||||
|
return strings.TrimSpace(a.GetCredential("project_id")) != ""
|
||||||
|
}
|
||||||
|
return oauthType == "code_assist"
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) CanGetUsage() bool {
|
func (a *Account) CanGetUsage() bool {
|
||||||
return a.Type == AccountTypeOAuth
|
return a.Type == AccountTypeOAuth
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ var (
|
|||||||
type AccountRepository interface {
|
type AccountRepository interface {
|
||||||
Create(ctx context.Context, account *Account) error
|
Create(ctx context.Context, account *Account) error
|
||||||
GetByID(ctx context.Context, id int64) (*Account, error)
|
GetByID(ctx context.Context, id int64) (*Account, error)
|
||||||
|
// GetByIDs fetches accounts by IDs in a single query.
|
||||||
|
// It should return all accounts found (missing IDs are ignored).
|
||||||
|
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||||
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
|
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
|
||||||
ExistsByID(ctx context.Context, id int64) (bool, error)
|
ExistsByID(ctx context.Context, id int64) (bool, error)
|
||||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||||
|
|||||||
@@ -40,6 +40,10 @@ func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, erro
|
|||||||
panic("unexpected GetByID call")
|
panic("unexpected GetByID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStub) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
|
||||||
|
panic("unexpected GetByIDs call")
|
||||||
|
}
|
||||||
|
|
||||||
// ExistsByID 返回预设的存在性检查结果。
|
// ExistsByID 返回预设的存在性检查结果。
|
||||||
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
|
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
|
||||||
func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||||
|
|||||||
@@ -93,10 +93,12 @@ type UsageProgress struct {
|
|||||||
|
|
||||||
// UsageInfo 账号使用量信息
|
// UsageInfo 账号使用量信息
|
||||||
type UsageInfo struct {
|
type UsageInfo struct {
|
||||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||||
|
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
||||||
|
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||||
@@ -122,17 +124,19 @@ type ClaudeUsageFetcher interface {
|
|||||||
|
|
||||||
// AccountUsageService 账号使用量查询服务
|
// AccountUsageService 账号使用量查询服务
|
||||||
type AccountUsageService struct {
|
type AccountUsageService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
usageLogRepo UsageLogRepository
|
usageLogRepo UsageLogRepository
|
||||||
usageFetcher ClaudeUsageFetcher
|
usageFetcher ClaudeUsageFetcher
|
||||||
|
geminiQuotaService *GeminiQuotaService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountUsageService 创建AccountUsageService实例
|
// NewAccountUsageService 创建AccountUsageService实例
|
||||||
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
|
||||||
return &AccountUsageService{
|
return &AccountUsageService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
usageLogRepo: usageLogRepo,
|
usageLogRepo: usageLogRepo,
|
||||||
usageFetcher: usageFetcher,
|
usageFetcher: usageFetcher,
|
||||||
|
geminiQuotaService: geminiQuotaService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
return nil, fmt.Errorf("get account failed: %w", err)
|
return nil, fmt.Errorf("get account failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account.Platform == PlatformGemini {
|
||||||
|
return s.getGeminiUsage(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
||||||
if account.CanGetUsage() {
|
if account.CanGetUsage() {
|
||||||
var apiResp *ClaudeUsageResponse
|
var apiResp *ClaudeUsageResponse
|
||||||
@@ -192,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||||
|
now := time.Now()
|
||||||
|
usage := &UsageInfo{
|
||||||
|
UpdatedAt: &now,
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
|
||||||
|
if !ok {
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
start := geminiDailyWindowStart(now)
|
||||||
|
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
totals := geminiAggregateUsage(stats)
|
||||||
|
resetAt := geminiDailyResetTime(now)
|
||||||
|
|
||||||
|
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
|
||||||
|
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
|
||||||
|
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
// addWindowStats 为 usage 数据添加窗口期统计
|
// addWindowStats 为 usage 数据添加窗口期统计
|
||||||
// 使用独立缓存(1 分钟),与 API 缓存分离
|
// 使用独立缓存(1 分钟),与 API 缓存分离
|
||||||
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
|
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||||
@@ -388,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
|||||||
// Setup Token无法获取7d数据
|
// Setup Token无法获取7d数据
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
|
||||||
|
if limit <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
utilization := (float64(used) / float64(limit)) * 100
|
||||||
|
remainingSeconds := int(resetAt.Sub(now).Seconds())
|
||||||
|
if remainingSeconds < 0 {
|
||||||
|
remainingSeconds = 0
|
||||||
|
}
|
||||||
|
resetCopy := resetAt
|
||||||
|
return &UsageProgress{
|
||||||
|
Utilization: utilization,
|
||||||
|
ResetsAt: &resetCopy,
|
||||||
|
RemainingSeconds: remainingSeconds,
|
||||||
|
WindowStats: &WindowStats{
|
||||||
|
Requests: used,
|
||||||
|
Tokens: tokens,
|
||||||
|
Cost: cost,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
// AdminService interface defines admin management operations
|
// AdminService interface defines admin management operations
|
||||||
type AdminService interface {
|
type AdminService interface {
|
||||||
// User management
|
// User management
|
||||||
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error)
|
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
|
||||||
GetUser(ctx context.Context, id int64) (*User, error)
|
GetUser(ctx context.Context, id int64) (*User, error)
|
||||||
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
|
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
|
||||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
||||||
@@ -35,6 +35,7 @@ type AdminService interface {
|
|||||||
// Account management
|
// Account management
|
||||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
||||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||||
|
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||||
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
|
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
|
||||||
DeleteAccount(ctx context.Context, id int64) error
|
DeleteAccount(ctx context.Context, id int64) error
|
||||||
@@ -69,7 +70,6 @@ type CreateUserInput struct {
|
|||||||
Email string
|
Email string
|
||||||
Password string
|
Password string
|
||||||
Username string
|
Username string
|
||||||
Wechat string
|
|
||||||
Notes string
|
Notes string
|
||||||
Balance float64
|
Balance float64
|
||||||
Concurrency int
|
Concurrency int
|
||||||
@@ -80,7 +80,6 @@ type UpdateUserInput struct {
|
|||||||
Email string
|
Email string
|
||||||
Password string
|
Password string
|
||||||
Username *string
|
Username *string
|
||||||
Wechat *string
|
|
||||||
Notes *string
|
Notes *string
|
||||||
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
||||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||||
@@ -251,9 +250,9 @@ func NewAdminService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// User management implementations
|
// User management implementations
|
||||||
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) {
|
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
|
||||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
|
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
@@ -268,7 +267,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
|||||||
user := &User{
|
user := &User{
|
||||||
Email: input.Email,
|
Email: input.Email,
|
||||||
Username: input.Username,
|
Username: input.Username,
|
||||||
Wechat: input.Wechat,
|
|
||||||
Notes: input.Notes,
|
Notes: input.Notes,
|
||||||
Role: RoleUser, // Always create as regular user, never admin
|
Role: RoleUser, // Always create as regular user, never admin
|
||||||
Balance: input.Balance,
|
Balance: input.Balance,
|
||||||
@@ -310,9 +308,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
if input.Username != nil {
|
if input.Username != nil {
|
||||||
user.Username = *input.Username
|
user.Username = *input.Username
|
||||||
}
|
}
|
||||||
if input.Wechat != nil {
|
|
||||||
user.Wechat = *input.Wechat
|
|
||||||
}
|
|
||||||
if input.Notes != nil {
|
if input.Notes != nil {
|
||||||
user.Notes = *input.Notes
|
user.Notes = *input.Notes
|
||||||
}
|
}
|
||||||
@@ -611,6 +606,19 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
|
|||||||
return s.accountRepo.GetByID(ctx, id)
|
return s.accountRepo.GetByID(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return []*Account{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts, err := s.accountRepo.GetByIDs(ctx, ids)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get accounts by IDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accounts, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
|
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
|
|||||||
Email: "user@test.com",
|
Email: "user@test.com",
|
||||||
Password: "strong-pass",
|
Password: "strong-pass",
|
||||||
Username: "tester",
|
Username: "tester",
|
||||||
Wechat: "wx",
|
|
||||||
Notes: "note",
|
Notes: "note",
|
||||||
Balance: 12.5,
|
Balance: 12.5,
|
||||||
Concurrency: 7,
|
Concurrency: 7,
|
||||||
@@ -31,7 +30,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
|
|||||||
require.Equal(t, int64(10), user.ID)
|
require.Equal(t, int64(10), user.ID)
|
||||||
require.Equal(t, input.Email, user.Email)
|
require.Equal(t, input.Email, user.Email)
|
||||||
require.Equal(t, input.Username, user.Username)
|
require.Equal(t, input.Username, user.Username)
|
||||||
require.Equal(t, input.Wechat, user.Wechat)
|
|
||||||
require.Equal(t, input.Notes, user.Notes)
|
require.Equal(t, input.Notes, user.Notes)
|
||||||
require.Equal(t, input.Balance, user.Balance)
|
require.Equal(t, input.Balance, user.Balance)
|
||||||
require.Equal(t, input.Concurrency, user.Concurrency)
|
require.Equal(t, input.Concurrency, user.Concurrency)
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationPar
|
|||||||
panic("unexpected List call")
|
panic("unexpected List call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error) {
|
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected ListWithFilters call")
|
panic("unexpected ListWithFilters call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ const (
|
|||||||
antigravityRetryMaxDelay = 16 * time.Second
|
antigravityRetryMaxDelay = 16 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// Antigravity 直接支持的模型
|
// Antigravity 直接支持的模型(精确匹配透传)
|
||||||
var antigravitySupportedModels = map[string]bool{
|
var antigravitySupportedModels = map[string]bool{
|
||||||
"claude-opus-4-5-thinking": true,
|
"claude-opus-4-5-thinking": true,
|
||||||
"claude-sonnet-4-5": true,
|
"claude-sonnet-4-5": true,
|
||||||
@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
|
|||||||
"gemini-3-flash": true,
|
"gemini-3-flash": true,
|
||||||
"gemini-3-pro-low": true,
|
"gemini-3-pro-low": true,
|
||||||
"gemini-3-pro-high": true,
|
"gemini-3-pro-high": true,
|
||||||
"gemini-3-pro-preview": true,
|
|
||||||
"gemini-3-pro-image": true,
|
"gemini-3-pro-image": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Antigravity 系统默认模型映射表(不支持 → 支持)
|
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
|
||||||
var antigravityModelMapping = map[string]string{
|
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
|
||||||
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
|
var antigravityPrefixMapping = []struct {
|
||||||
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
|
prefix string
|
||||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
|
target string
|
||||||
"claude-opus-4": "claude-opus-4-5-thinking",
|
}{
|
||||||
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
|
// 长前缀优先
|
||||||
"claude-haiku-4": "gemini-3-flash",
|
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||||
"claude-haiku-4-5": "gemini-3-flash",
|
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||||
"claude-3-haiku-20240307": "gemini-3-flash",
|
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||||
"claude-haiku-4-5-20251001": "gemini-3-flash",
|
{"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
|
||||||
// 生图模型:官方名 → Antigravity 内部名
|
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
|
||||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
{"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
|
||||||
|
{"claude-sonnet-4", "claude-sonnet-4-5"},
|
||||||
|
{"claude-haiku-4", "gemini-3-flash"},
|
||||||
|
{"claude-opus-4", "claude-opus-4-5-thinking"},
|
||||||
|
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
|
||||||
}
|
}
|
||||||
|
|
||||||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||||||
@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getMappedModel 获取映射后的模型名
|
// getMappedModel 获取映射后的模型名
|
||||||
|
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
|
||||||
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
|
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
|
||||||
// 1. 优先使用账户级映射(复用现有方法)
|
// 1. 账户级映射(用户自定义优先)
|
||||||
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
|
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
|
||||||
return mapped
|
return mapped
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 系统默认映射
|
// 2. 直接支持的模型透传
|
||||||
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
|
if antigravitySupportedModels[requestedModel] {
|
||||||
return mapped
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Gemini 模型透传
|
|
||||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
|
||||||
return requestedModel
|
return requestedModel
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Claude 前缀透传直接支持的模型
|
// 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview)
|
||||||
if antigravitySupportedModels[requestedModel] {
|
for _, pm := range antigravityPrefixMapping {
|
||||||
|
if strings.HasPrefix(requestedModel, pm.prefix) {
|
||||||
|
return pm.target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Gemini 模型透传(未匹配到前缀的 gemini 模型)
|
||||||
|
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||||
return requestedModel
|
return requestedModel
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsModelSupported 检查模型是否被支持
|
// IsModelSupported 检查模型是否被支持
|
||||||
|
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
|
||||||
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
|
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
|
||||||
// 直接支持的模型
|
return strings.HasPrefix(requestedModel, "claude-") ||
|
||||||
if antigravitySupportedModels[requestedModel] {
|
strings.HasPrefix(requestedModel, "gemini-")
|
||||||
return true
|
|
||||||
}
|
|
||||||
// 可映射的模型
|
|
||||||
if _, ok := antigravityModelMapping[requestedModel]; ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Gemini 前缀透传
|
|
||||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Claude 模型支持(通过默认映射)
|
|
||||||
if strings.HasPrefix(requestedModel, "claude-") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestConnectionResult 测试连接结果
|
// TestConnectionResult 测试连接结果
|
||||||
@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return nil, fmt.Errorf("transform request: %w", err)
|
return nil, fmt.Errorf("transform request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 调试:记录转换后的请求体(仅记录前 2000 字符)
|
||||||
|
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
|
||||||
|
truncated := string(bodyJSON)
|
||||||
|
if len(truncated) > 2000 {
|
||||||
|
truncated = truncated[:2000] + "..."
|
||||||
|
}
|
||||||
|
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
|
||||||
|
}
|
||||||
|
|
||||||
// 构建上游 action
|
// 构建上游 action
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if claudeReq.Stream {
|
if claudeReq.Stream {
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
|||||||
name: "系统映射 - claude-sonnet-4-5-20250929",
|
name: "系统映射 - claude-sonnet-4-5-20250929",
|
||||||
requestedModel: "claude-sonnet-4-5-20250929",
|
requestedModel: "claude-sonnet-4-5-20250929",
|
||||||
accountMapping: nil,
|
accountMapping: nil,
|
||||||
expected: "claude-sonnet-4-5-thinking",
|
expected: "claude-sonnet-4-5",
|
||||||
},
|
},
|
||||||
|
|
||||||
// 3. Gemini 透传
|
// 3. Gemini 透传
|
||||||
|
|||||||
@@ -18,6 +18,11 @@ type ConcurrencyCache interface {
|
|||||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||||
|
|
||||||
|
// 账号等待队列(账号级)
|
||||||
|
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
|
||||||
|
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
|
||||||
|
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
|
||||||
|
|
||||||
// 用户槽位管理
|
// 用户槽位管理
|
||||||
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
||||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||||
@@ -27,6 +32,12 @@ type ConcurrencyCache interface {
|
|||||||
// 等待队列计数(只在首次创建时设置 TTL)
|
// 等待队列计数(只在首次创建时设置 TTL)
|
||||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||||
|
|
||||||
|
// 批量负载查询(只读)
|
||||||
|
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
||||||
|
|
||||||
|
// 清理过期槽位(后台任务)
|
||||||
|
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateRequestID generates a unique request ID for concurrency slot tracking
|
// generateRequestID generates a unique request ID for concurrency slot tracking
|
||||||
@@ -61,6 +72,18 @@ type AcquireResult struct {
|
|||||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AccountWithConcurrency struct {
|
||||||
|
ID int64
|
||||||
|
MaxConcurrency int
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccountLoadInfo struct {
|
||||||
|
AccountID int64
|
||||||
|
CurrentConcurrency int
|
||||||
|
WaitingCount int
|
||||||
|
LoadRate int // 0-100+ (percent)
|
||||||
|
}
|
||||||
|
|
||||||
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
|
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
|
||||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||||
// Returns a release function that MUST be called when the request completes.
|
// Returns a release function that MUST be called when the request completes.
|
||||||
@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
||||||
|
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
|
if s.cache == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecrementAccountWaitCount decrements the wait queue counter for an account.
|
||||||
|
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||||
|
if s.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
|
||||||
|
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountWaitingCount gets current wait queue count for an account.
|
||||||
|
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
if s.cache == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return s.cache.GetAccountWaitingCount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||||
func CalculateMaxWait(userConcurrency int) int {
|
func CalculateMaxWait(userConcurrency int) int {
|
||||||
@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int {
|
|||||||
return userConcurrency + defaultExtraWaitSlots
|
return userConcurrency + defaultExtraWaitSlots
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountsLoadBatch returns load info for multiple accounts.
|
||||||
|
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||||
|
if s.cache == nil {
|
||||||
|
return map[int64]*AccountLoadInfo{}, nil
|
||||||
|
}
|
||||||
|
return s.cache.GetAccountsLoadBatch(ctx, accounts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
|
||||||
|
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||||
|
if s.cache == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
|
||||||
|
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
|
||||||
|
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
runCleanup := func() {
|
||||||
|
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
accounts, err := accountRepo.ListSchedulable(listCtx)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Warning: list schedulable accounts failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, account := range accounts {
|
||||||
|
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
|
||||||
|
accountCancel()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
runCleanup()
|
||||||
|
for range ticker.C {
|
||||||
|
runCleanup()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||||
// Returns a map of accountID -> current concurrency count
|
// Returns a map of accountID -> current concurrency count
|
||||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||||
|
|||||||
@@ -91,6 +91,9 @@ const (
|
|||||||
|
|
||||||
// 管理员 API Key
|
// 管理员 API Key
|
||||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||||
|
|
||||||
|
// Gemini 配额策略(JSON)
|
||||||
|
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Admin API Key prefix (distinct from user "sk-" keys)
|
// Admin API Key prefix (distinct from user "sk-" keys)
|
||||||
|
|||||||
@@ -32,6 +32,16 @@ func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Ac
|
|||||||
return nil, errors.New("account not found")
|
return nil, errors.New("account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAccountRepoForPlatform) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
|
||||||
|
var result []*Account
|
||||||
|
for _, id := range ids {
|
||||||
|
if acc, ok := m.accountsByID[id]; ok {
|
||||||
|
result = append(result, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAccountRepoForPlatform) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
func (m *mockAccountRepoForPlatform) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||||
if m.accountsByID == nil {
|
if m.accountsByID == nil {
|
||||||
return false, nil
|
return false, nil
|
||||||
@@ -261,6 +271,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
|
|||||||
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
|
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||||
|
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
|
||||||
|
}
|
||||||
|
|
||||||
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
|
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
|
||||||
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
|
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@@ -576,6 +614,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
|||||||
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||||
|
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
|
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
|
||||||
repo := &mockAccountRepoForPlatform{
|
repo := &mockAccountRepoForPlatform{
|
||||||
accounts: []Account{
|
accounts: []Account{
|
||||||
@@ -783,3 +847,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mockConcurrencyService for testing
|
||||||
|
type mockConcurrencyService struct {
|
||||||
|
accountLoads map[int64]*AccountLoadInfo
|
||||||
|
accountWaitCounts map[int64]int
|
||||||
|
acquireResults map[int64]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||||
|
if m.accountLoads == nil {
|
||||||
|
return map[int64]*AccountLoadInfo{}, nil
|
||||||
|
}
|
||||||
|
result := make(map[int64]*AccountLoadInfo)
|
||||||
|
for _, acc := range accounts {
|
||||||
|
if load, ok := m.accountLoads[acc.ID]; ok {
|
||||||
|
result[acc.ID] = load
|
||||||
|
} else {
|
||||||
|
result[acc.ID] = &AccountLoadInfo{
|
||||||
|
AccountID: acc.ID,
|
||||||
|
CurrentConcurrency: 0,
|
||||||
|
WaitingCount: 0,
|
||||||
|
LoadRate: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
if m.accountWaitCounts == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return m.accountWaitCounts[accountID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
|
||||||
|
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: nil, // No concurrency service
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
excludedIDs := map[int64]struct{}{1: {}}
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("无可用账号-返回错误", func(t *testing.T) {
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Contains(t, err.Error(), "no available accounts")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,12 +13,14 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -66,6 +68,20 @@ type GatewayCache interface {
|
|||||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AccountWaitPlan struct {
|
||||||
|
AccountID int64
|
||||||
|
MaxConcurrency int
|
||||||
|
Timeout time.Duration
|
||||||
|
MaxWaiting int
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccountSelectionResult struct {
|
||||||
|
Account *Account
|
||||||
|
Acquired bool
|
||||||
|
ReleaseFunc func()
|
||||||
|
WaitPlan *AccountWaitPlan // nil means no wait allowed
|
||||||
|
}
|
||||||
|
|
||||||
// ClaudeUsage 表示Claude API返回的usage信息
|
// ClaudeUsage 表示Claude API返回的usage信息
|
||||||
type ClaudeUsage struct {
|
type ClaudeUsage struct {
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
@@ -108,6 +124,7 @@ type GatewayService struct {
|
|||||||
identityService *IdentityService
|
identityService *IdentityService
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
deferredService *DeferredService
|
deferredService *DeferredService
|
||||||
|
concurrencyService *ConcurrencyService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGatewayService creates a new GatewayService
|
// NewGatewayService creates a new GatewayService
|
||||||
@@ -119,6 +136,7 @@ func NewGatewayService(
|
|||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
cache GatewayCache,
|
cache GatewayCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
|
concurrencyService *ConcurrencyService,
|
||||||
billingService *BillingService,
|
billingService *BillingService,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
@@ -134,6 +152,7 @@ func NewGatewayService(
|
|||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
concurrencyService: concurrencyService,
|
||||||
billingService: billingService,
|
billingService: billingService,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
@@ -183,6 +202,14 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BindStickySession sets session -> account binding with standard TTL.
|
||||||
|
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||||
|
if sessionHash == "" || accountID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||||
if parsed == nil {
|
if parsed == nil {
|
||||||
return ""
|
return ""
|
||||||
@@ -332,8 +359,354 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
|||||||
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||||
|
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||||
|
cfg := s.schedulingConfig()
|
||||||
|
var stickyAccountID int64
|
||||||
|
if sessionHash != "" && s.cache != nil {
|
||||||
|
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
|
||||||
|
stickyAccountID = accountID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||||
|
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||||
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||||
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: account.ID,
|
||||||
|
MaxConcurrency: account.Concurrency,
|
||||||
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: account.ID,
|
||||||
|
MaxConcurrency: account.Concurrency,
|
||||||
|
Timeout: cfg.FallbackWaitTimeout,
|
||||||
|
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
preferOAuth := platform == PlatformGemini
|
||||||
|
|
||||||
|
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(accounts) == 0 {
|
||||||
|
return nil, errors.New("no available accounts")
|
||||||
|
}
|
||||||
|
|
||||||
|
isExcluded := func(accountID int64) bool {
|
||||||
|
if excludedIDs == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, excluded := excludedIDs[accountID]
|
||||||
|
return excluded
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Layer 1: 粘性会话优先 ============
|
||||||
|
if sessionHash != "" {
|
||||||
|
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||||
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||||
|
account.IsSchedulable() &&
|
||||||
|
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||||
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: accountID,
|
||||||
|
MaxConcurrency: account.Concurrency,
|
||||||
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Layer 2: 负载感知选择 ============
|
||||||
|
candidates := make([]*Account, 0, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
acc := &accounts[i]
|
||||||
|
if isExcluded(acc.ID) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidates = append(candidates, acc)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil, errors.New("no available accounts")
|
||||||
|
}
|
||||||
|
|
||||||
|
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||||
|
for _, acc := range candidates {
|
||||||
|
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||||
|
ID: acc.ID,
|
||||||
|
MaxConcurrency: acc.Concurrency,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||||
|
if err != nil {
|
||||||
|
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
type accountWithLoad struct {
|
||||||
|
account *Account
|
||||||
|
loadInfo *AccountLoadInfo
|
||||||
|
}
|
||||||
|
var available []accountWithLoad
|
||||||
|
for _, acc := range candidates {
|
||||||
|
loadInfo := loadMap[acc.ID]
|
||||||
|
if loadInfo == nil {
|
||||||
|
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||||
|
}
|
||||||
|
if loadInfo.LoadRate < 100 {
|
||||||
|
available = append(available, accountWithLoad{
|
||||||
|
account: acc,
|
||||||
|
loadInfo: loadInfo,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(available) > 0 {
|
||||||
|
sort.SliceStable(available, func(i, j int) bool {
|
||||||
|
a, b := available[i], available[j]
|
||||||
|
if a.account.Priority != b.account.Priority {
|
||||||
|
return a.account.Priority < b.account.Priority
|
||||||
|
}
|
||||||
|
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||||
|
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||||
|
return true
|
||||||
|
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||||
|
return false
|
||||||
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||||
|
if preferOAuth && a.account.Type != b.account.Type {
|
||||||
|
return a.account.Type == AccountTypeOAuth
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, item := range available {
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
if sessionHash != "" {
|
||||||
|
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: item.account,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Layer 3: 兜底排队 ============
|
||||||
|
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
|
||||||
|
for _, acc := range candidates {
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: acc,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: acc.ID,
|
||||||
|
MaxConcurrency: acc.Concurrency,
|
||||||
|
Timeout: cfg.FallbackWaitTimeout,
|
||||||
|
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("no available accounts")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||||
|
ordered := append([]*Account(nil), candidates...)
|
||||||
|
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||||
|
|
||||||
|
for _, acc := range ordered {
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
if sessionHash != "" {
|
||||||
|
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: acc,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||||
|
if s.cfg != nil {
|
||||||
|
return s.cfg.Gateway.Scheduling
|
||||||
|
}
|
||||||
|
return config.GatewaySchedulingConfig{
|
||||||
|
StickySessionMaxWaiting: 3,
|
||||||
|
StickySessionWaitTimeout: 45 * time.Second,
|
||||||
|
FallbackWaitTimeout: 30 * time.Second,
|
||||||
|
FallbackMaxWaiting: 100,
|
||||||
|
LoadBatchEnabled: true,
|
||||||
|
SlotCleanupInterval: 30 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
||||||
|
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||||
|
if hasForcePlatform && forcePlatform != "" {
|
||||||
|
return forcePlatform, true, nil
|
||||||
|
}
|
||||||
|
if groupID != nil {
|
||||||
|
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, fmt.Errorf("get group failed: %w", err)
|
||||||
|
}
|
||||||
|
return group.Platform, false, nil
|
||||||
|
}
|
||||||
|
return PlatformAnthropic, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||||
|
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||||
|
if useMixed {
|
||||||
|
platforms := []string{platform, PlatformAntigravity}
|
||||||
|
var accounts []Account
|
||||||
|
var err error
|
||||||
|
if groupID != nil {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
||||||
|
} else {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, useMixed, err
|
||||||
|
}
|
||||||
|
filtered := make([]Account, 0, len(accounts))
|
||||||
|
for _, acc := range accounts {
|
||||||
|
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, acc)
|
||||||
|
}
|
||||||
|
return filtered, useMixed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var accounts []Account
|
||||||
|
var err error
|
||||||
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
} else if groupID != nil {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||||
|
if err == nil && len(accounts) == 0 && hasForcePlatform {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, useMixed, err
|
||||||
|
}
|
||||||
|
return accounts, useMixed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
|
||||||
|
if account == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if useMixed {
|
||||||
|
if account.Platform == platform {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
|
||||||
|
}
|
||||||
|
return account.Platform == platform
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||||
|
if s.concurrencyService == nil {
|
||||||
|
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
||||||
|
}
|
||||||
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||||
|
sort.SliceStable(accounts, func(i, j int) bool {
|
||||||
|
a, b := accounts[i], accounts[j]
|
||||||
|
if a.Priority != b.Priority {
|
||||||
|
return a.Priority < b.Priority
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case a.LastUsedAt == nil && b.LastUsedAt != nil:
|
||||||
|
return true
|
||||||
|
case a.LastUsedAt != nil && b.LastUsedAt == nil:
|
||||||
|
return false
|
||||||
|
case a.LastUsedAt == nil && b.LastUsedAt == nil:
|
||||||
|
if preferOAuth && a.Type != b.Type {
|
||||||
|
return a.Type == AccountTypeOAuth
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return a.LastUsedAt.Before(*b.LastUsedAt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||||
|
preferOAuth := platform == PlatformGemini
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||||
@@ -389,7 +762,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||||
// keep selected (never used is preferred)
|
// keep selected (never used is preferred)
|
||||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||||
// keep selected (both never used)
|
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||||
|
selected = acc
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||||
selected = acc
|
selected = acc
|
||||||
@@ -419,6 +794,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
||||||
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
||||||
platforms := []string{nativePlatform, PlatformAntigravity}
|
platforms := []string{nativePlatform, PlatformAntigravity}
|
||||||
|
preferOAuth := nativePlatform == PlatformGemini
|
||||||
|
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
@@ -478,7 +854,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||||
// keep selected (never used is preferred)
|
// keep selected (never used is preferred)
|
||||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||||
// keep selected (both never used)
|
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||||
|
selected = acc
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||||
selected = acc
|
selected = acc
|
||||||
@@ -515,24 +893,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
|
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
|
||||||
|
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
|
||||||
func IsAntigravityModelSupported(requestedModel string) bool {
|
func IsAntigravityModelSupported(requestedModel string) bool {
|
||||||
// 直接支持的模型
|
return strings.HasPrefix(requestedModel, "claude-") ||
|
||||||
if antigravitySupportedModels[requestedModel] {
|
strings.HasPrefix(requestedModel, "gemini-")
|
||||||
return true
|
|
||||||
}
|
|
||||||
// 可映射的模型
|
|
||||||
if _, ok := antigravityModelMapping[requestedModel]; ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Gemini 前缀透传
|
|
||||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5)
|
|
||||||
if strings.HasPrefix(requestedModel, "claude-") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken 获取账号凭证
|
// GetAccessToken 获取账号凭证
|
||||||
@@ -684,6 +1048,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
|
|
||||||
// 处理错误响应(不可重试的错误)
|
// 处理错误响应(不可重试的错误)
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||||
|
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||||
|
respBody, readErr := io.ReadAll(resp.Body)
|
||||||
|
if readErr != nil {
|
||||||
|
// ReadAll failed, fall back to normal error handling without consuming the stream
|
||||||
|
return s.handleErrorResponse(ctx, resp, c, account)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
|
if s.shouldFailoverOn400(respBody) {
|
||||||
|
if s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
log.Printf(
|
||||||
|
"Account %d: 400 error, attempting failover: %s",
|
||||||
|
account.ID,
|
||||||
|
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
log.Printf("Account %d: 400 error, attempting failover", account.ID)
|
||||||
|
}
|
||||||
|
s.handleFailoverSideEffects(ctx, resp, account)
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
|
}
|
||||||
|
}
|
||||||
return s.handleErrorResponse(ctx, resp, c, account)
|
return s.handleErrorResponse(ctx, resp, c, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -786,6 +1174,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||||
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||||
|
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||||
|
if requestNeedsBetaFeatures(body) {
|
||||||
|
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||||
|
req.Header.Set("anthropic-beta", beta)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
@@ -838,6 +1233,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
|
|||||||
return claude.DefaultBetaHeader
|
return claude.DefaultBetaHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func requestNeedsBetaFeatures(body []byte) bool {
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultApiKeyBetaHeader(body []byte) string {
|
||||||
|
modelID := gjson.GetBytes(body, "model").String()
|
||||||
|
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||||
|
return claude.ApiKeyHaikuBetaHeader
|
||||||
|
}
|
||||||
|
return claude.ApiKeyBetaHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateForLog(b []byte, maxBytes int) string {
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
if len(b) > maxBytes {
|
||||||
|
b = b[:maxBytes]
|
||||||
|
}
|
||||||
|
s := string(b)
|
||||||
|
// 保持一行,避免污染日志格式
|
||||||
|
s = strings.ReplaceAll(s, "\n", "\\n")
|
||||||
|
s = strings.ReplaceAll(s, "\r", "\\r")
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||||
|
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
|
||||||
|
// 默认保守:无法识别则不切换。
|
||||||
|
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||||
|
if msg == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。
|
||||||
|
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
|
||||||
|
if strings.Contains(msg, "anthropic-beta") ||
|
||||||
|
strings.Contains(msg, "beta feature") ||
|
||||||
|
strings.Contains(msg, "requires beta") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
|
||||||
|
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractUpstreamErrorMessage(body []byte) string {
|
||||||
|
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
|
||||||
|
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
|
||||||
|
inner := strings.TrimSpace(m)
|
||||||
|
// 有些上游会把完整 JSON 作为字符串塞进 message
|
||||||
|
if strings.HasPrefix(inner, "{") {
|
||||||
|
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
|
||||||
|
return innerMsg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// 兜底:尝试顶层 message
|
||||||
|
return gjson.GetBytes(body, "message").String()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
@@ -850,6 +1322,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
|||||||
|
|
||||||
switch resp.StatusCode {
|
switch resp.StatusCode {
|
||||||
case 400:
|
case 400:
|
||||||
|
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
log.Printf(
|
||||||
|
"Upstream 400 error (account=%d platform=%s type=%s): %s",
|
||||||
|
account.ID,
|
||||||
|
account.Platform,
|
||||||
|
account.Type,
|
||||||
|
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||||
|
)
|
||||||
|
}
|
||||||
c.Data(http.StatusBadRequest, "application/json", body)
|
c.Data(http.StatusBadRequest, "application/json", body)
|
||||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||||
case 401:
|
case 401:
|
||||||
@@ -1329,6 +1811,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
// 标记账号状态(429/529等)
|
// 标记账号状态(429/529等)
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
|
||||||
|
// 记录上游错误摘要便于排障(不回显请求内容)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
log.Printf(
|
||||||
|
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
|
||||||
|
resp.StatusCode,
|
||||||
|
account.ID,
|
||||||
|
account.Platform,
|
||||||
|
account.Type,
|
||||||
|
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// 返回简化的错误响应
|
// 返回简化的错误响应
|
||||||
errMsg := "Upstream request failed"
|
errMsg := "Upstream request failed"
|
||||||
switch resp.StatusCode {
|
switch resp.StatusCode {
|
||||||
@@ -1409,6 +1903,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
// OAuth 账号:处理 anthropic-beta header
|
// OAuth 账号:处理 anthropic-beta header
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||||
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||||
|
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||||
|
if requestNeedsBetaFeatures(body) {
|
||||||
|
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||||
|
req.Header.Set("anthropic-beta", beta)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
@@ -1424,3 +1925,58 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAvailableModels returns the list of models available for a group
|
||||||
|
// It aggregates model_mapping keys from all schedulable accounts in the group
|
||||||
|
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
||||||
|
var accounts []Account
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if groupID != nil {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
|
||||||
|
} else {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulable(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil || len(accounts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter by platform if specified
|
||||||
|
if platform != "" {
|
||||||
|
filtered := make([]Account, 0)
|
||||||
|
for _, acc := range accounts {
|
||||||
|
if acc.Platform == platform {
|
||||||
|
filtered = append(filtered, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
accounts = filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect unique models from all accounts
|
||||||
|
modelSet := make(map[string]struct{})
|
||||||
|
hasAnyMapping := false
|
||||||
|
|
||||||
|
for _, acc := range accounts {
|
||||||
|
mapping := acc.GetModelMapping()
|
||||||
|
if len(mapping) > 0 {
|
||||||
|
hasAnyMapping = true
|
||||||
|
for model := range mapping {
|
||||||
|
modelSet[model] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no account has model_mapping, return nil (use default)
|
||||||
|
if !hasAnyMapping {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to slice
|
||||||
|
models := make([]string, 0, len(modelSet))
|
||||||
|
for model := range modelSet {
|
||||||
|
models = append(models, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|||||||
@@ -116,8 +116,20 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
valid = true
|
valid = true
|
||||||
}
|
}
|
||||||
if valid {
|
if valid {
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
usable := true
|
||||||
return account, nil
|
if s.rateLimitService != nil && requestedModel != "" {
|
||||||
|
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
usable = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if usable {
|
||||||
|
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -157,6 +169,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if s.rateLimitService != nil && requestedModel != "" {
|
||||||
|
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
if selected == nil {
|
if selected == nil {
|
||||||
selected = acc
|
selected = acc
|
||||||
continue
|
continue
|
||||||
@@ -1886,13 +1907,44 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
|
|||||||
if statusCode != 429 {
|
if statusCode != 429 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oauthType := account.GeminiOAuthType()
|
||||||
|
tierID := account.GeminiTierID()
|
||||||
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
|
isCodeAssist := account.IsGeminiCodeAssist()
|
||||||
|
|
||||||
resetAt := ParseGeminiRateLimitResetTime(body)
|
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||||
if resetAt == nil {
|
if resetAt == nil {
|
||||||
ra := time.Now().Add(5 * time.Minute)
|
// 根据账号类型使用不同的默认重置时间
|
||||||
|
var ra time.Time
|
||||||
|
if isCodeAssist {
|
||||||
|
// Code Assist: fallback cooldown by tier
|
||||||
|
cooldown := geminiCooldownForTier(tierID)
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
|
||||||
|
}
|
||||||
|
ra = time.Now().Add(cooldown)
|
||||||
|
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
|
||||||
|
} else {
|
||||||
|
// API Key / AI Studio OAuth: PST 午夜
|
||||||
|
if ts := nextGeminiDailyResetUnix(); ts != nil {
|
||||||
|
ra = time.Unix(*ts, 0)
|
||||||
|
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
|
||||||
|
} else {
|
||||||
|
// 兜底:5 分钟
|
||||||
|
ra = time.Now().Add(5 * time.Minute)
|
||||||
|
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
|
|
||||||
|
// 使用解析到的重置时间
|
||||||
|
resetTime := time.Unix(*resetAt, 0)
|
||||||
|
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
|
||||||
|
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
|
||||||
|
account.ID, resetTime, oauthType, tierID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
|
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||||
@@ -1948,16 +2000,7 @@ func looksLikeGeminiDailyQuota(message string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func nextGeminiDailyResetUnix() *int64 {
|
func nextGeminiDailyResetUnix() *int64 {
|
||||||
loc, err := time.LoadLocation("America/Los_Angeles")
|
reset := geminiDailyResetTime(time.Now())
|
||||||
if err != nil {
|
|
||||||
// Fallback: PST without DST.
|
|
||||||
loc = time.FixedZone("PST", -8*3600)
|
|
||||||
}
|
|
||||||
now := time.Now().In(loc)
|
|
||||||
reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc)
|
|
||||||
if !reset.After(now) {
|
|
||||||
reset = reset.Add(24 * time.Hour)
|
|
||||||
}
|
|
||||||
ts := reset.Unix()
|
ts := reset.Unix()
|
||||||
return &ts
|
return &ts
|
||||||
}
|
}
|
||||||
@@ -2278,11 +2321,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
|||||||
"properties": map[string]any{},
|
"properties": map[string]any{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 清理 JSON Schema
|
||||||
|
cleanedParams := cleanToolSchema(params)
|
||||||
|
|
||||||
funcDecls = append(funcDecls, map[string]any{
|
funcDecls = append(funcDecls, map[string]any{
|
||||||
"name": name,
|
"name": name,
|
||||||
"description": desc,
|
"description": desc,
|
||||||
"parameters": params,
|
"parameters": cleanedParams,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2296,6 +2341,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段
|
||||||
|
func cleanToolSchema(schema any) any {
|
||||||
|
if schema == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := schema.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
cleaned := make(map[string]any)
|
||||||
|
for key, value := range v {
|
||||||
|
// 跳过不支持的字段
|
||||||
|
if key == "$schema" || key == "$id" || key == "$ref" ||
|
||||||
|
key == "additionalProperties" || key == "minLength" ||
|
||||||
|
key == "maxLength" || key == "minItems" || key == "maxItems" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 递归清理嵌套对象
|
||||||
|
cleaned[key] = cleanToolSchema(value)
|
||||||
|
}
|
||||||
|
// 规范化 type 字段为大写
|
||||||
|
if typeVal, ok := cleaned["type"].(string); ok {
|
||||||
|
cleaned["type"] = strings.ToUpper(typeVal)
|
||||||
|
}
|
||||||
|
return cleaned
|
||||||
|
case []any:
|
||||||
|
cleaned := make([]any, len(v))
|
||||||
|
for i, item := range v {
|
||||||
|
cleaned[i] = cleanToolSchema(item)
|
||||||
|
}
|
||||||
|
return cleaned
|
||||||
|
default:
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
|
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
|
||||||
out := make(map[string]any)
|
out := make(map[string]any)
|
||||||
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {
|
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {
|
||||||
|
|||||||
128
backend/internal/service/gemini_messages_compat_service_test.go
Normal file
128
backend/internal/service/gemini_messages_compat_service_test.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||||||
|
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tools any
|
||||||
|
expectedLen int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Standard tools",
|
||||||
|
tools: []any{
|
||||||
|
map[string]any{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info",
|
||||||
|
"input_schema": map[string]any{"type": "object"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLen: 1,
|
||||||
|
description: "标准工具格式应该正常转换",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Custom type tool (MCP format)",
|
||||||
|
tools: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "custom",
|
||||||
|
"name": "mcp_tool",
|
||||||
|
"custom": map[string]any{
|
||||||
|
"description": "MCP tool description",
|
||||||
|
"input_schema": map[string]any{"type": "object"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLen: 1,
|
||||||
|
description: "Custom类型工具应该从custom字段读取",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed standard and custom tools",
|
||||||
|
tools: []any{
|
||||||
|
map[string]any{
|
||||||
|
"name": "standard_tool",
|
||||||
|
"description": "Standard",
|
||||||
|
"input_schema": map[string]any{"type": "object"},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "custom",
|
||||||
|
"name": "custom_tool",
|
||||||
|
"custom": map[string]any{
|
||||||
|
"description": "Custom",
|
||||||
|
"input_schema": map[string]any{"type": "object"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLen: 1,
|
||||||
|
description: "混合工具应该都能正确转换",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Custom tool without custom field",
|
||||||
|
tools: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "custom",
|
||||||
|
"name": "invalid_custom",
|
||||||
|
// 缺少 custom 字段
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLen: 0, // 应该被跳过
|
||||||
|
description: "缺少custom字段的custom工具应该被跳过",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := convertClaudeToolsToGeminiTools(tt.tools)
|
||||||
|
|
||||||
|
if tt.expectedLen == 0 {
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("%s: expected nil result, got %v", tt.description, result)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Fatalf("%s: expected non-nil result", tt.description)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
toolDecl, ok := result[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolsArr, _ := tt.tools.([]any)
|
||||||
|
expectedFuncCount := 0
|
||||||
|
for _, tool := range toolsArr {
|
||||||
|
toolMap, _ := tool.(map[string]any)
|
||||||
|
if toolMap["name"] != "" {
|
||||||
|
// 检查是否为有效的custom工具
|
||||||
|
if toolMap["type"] == "custom" {
|
||||||
|
if toolMap["custom"] != nil {
|
||||||
|
expectedFuncCount++
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
expectedFuncCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(funcDecls) != expectedFuncCount {
|
||||||
|
t.Errorf("%s: expected %d function declarations, got %d",
|
||||||
|
tt.description, expectedFuncCount, len(funcDecls))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,6 +25,16 @@ func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Acco
|
|||||||
return nil, errors.New("account not found")
|
return nil, errors.New("account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAccountRepoForGemini) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
|
||||||
|
var result []*Account
|
||||||
|
for _, id := range ids {
|
||||||
|
if acc, ok := m.accountsByID[id]; ok {
|
||||||
|
result = append(result, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||||
if m.accountsByID == nil {
|
if m.accountsByID == nil {
|
||||||
return false, nil
|
return false, nil
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,6 +17,26 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TierAIPremium = "AI_PREMIUM"
|
||||||
|
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
|
||||||
|
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
|
||||||
|
TierFree = "FREE"
|
||||||
|
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
|
||||||
|
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
GB = 1024 * 1024 * 1024
|
||||||
|
TB = 1024 * GB
|
||||||
|
|
||||||
|
StorageTierUnlimited = 100 * TB // 100TB
|
||||||
|
StorageTierAIPremium = 2 * TB // 2TB
|
||||||
|
StorageTierStandard = 200 * GB // 200GB
|
||||||
|
StorageTierBasic = 100 * GB // 100GB
|
||||||
|
StorageTierFree = 15 * GB // 15GB
|
||||||
|
)
|
||||||
|
|
||||||
type GeminiOAuthService struct {
|
type GeminiOAuthService struct {
|
||||||
sessionStore *geminicli.SessionStore
|
sessionStore *geminicli.SessionStore
|
||||||
proxyRepo ProxyRepository
|
proxyRepo ProxyRepository
|
||||||
@@ -88,13 +109,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
|||||||
|
|
||||||
// OAuth client selection:
|
// OAuth client selection:
|
||||||
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
|
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
|
||||||
|
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
|
||||||
// - ai_studio: requires a user-provided OAuth client.
|
// - ai_studio: requires a user-provided OAuth client.
|
||||||
oauthCfg := geminicli.OAuthConfig{
|
oauthCfg := geminicli.OAuthConfig{
|
||||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||||
}
|
}
|
||||||
if oauthType == "code_assist" {
|
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||||
oauthCfg.ClientID = ""
|
oauthCfg.ClientID = ""
|
||||||
oauthCfg.ClientSecret = ""
|
oauthCfg.ClientSecret = ""
|
||||||
}
|
}
|
||||||
@@ -155,14 +177,152 @@ type GeminiExchangeCodeInput struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GeminiTokenInfo struct {
|
type GeminiTokenInfo struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
ExpiresAt int64 `json:"expires_at"`
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
Scope string `json:"scope,omitempty"`
|
Scope string `json:"scope,omitempty"`
|
||||||
ProjectID string `json:"project_id,omitempty"`
|
ProjectID string `json:"project_id,omitempty"`
|
||||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||||
|
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
|
||||||
|
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateTierID validates tier_id format and length
|
||||||
|
func validateTierID(tierID string) error {
|
||||||
|
if tierID == "" {
|
||||||
|
return nil // Empty is allowed
|
||||||
|
}
|
||||||
|
if len(tierID) > 64 {
|
||||||
|
return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
|
||||||
|
}
|
||||||
|
// Allow alphanumeric, underscore, hyphen, and slash (for tier paths)
|
||||||
|
if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) {
|
||||||
|
return fmt.Errorf("tier_id contains invalid characters")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
|
||||||
|
// Prioritizes IsDefault tier, falls back to first non-empty tier
|
||||||
|
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
|
||||||
|
tierID := "LEGACY"
|
||||||
|
// First pass: look for default tier
|
||||||
|
for _, tier := range allowedTiers {
|
||||||
|
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
|
||||||
|
tierID = strings.TrimSpace(tier.ID)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Second pass: if still LEGACY, take first non-empty tier
|
||||||
|
if tierID == "LEGACY" {
|
||||||
|
for _, tier := range allowedTiers {
|
||||||
|
if strings.TrimSpace(tier.ID) != "" {
|
||||||
|
tierID = strings.TrimSpace(tier.ID)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tierID
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferGoogleOneTier infers Google One tier from Drive storage limit
|
||||||
|
func inferGoogleOneTier(storageBytes int64) string {
|
||||||
|
if storageBytes <= 0 {
|
||||||
|
return TierGoogleOneUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
if storageBytes > StorageTierUnlimited {
|
||||||
|
return TierGoogleOneUnlimited
|
||||||
|
}
|
||||||
|
if storageBytes >= StorageTierAIPremium {
|
||||||
|
return TierAIPremium
|
||||||
|
}
|
||||||
|
if storageBytes >= StorageTierStandard {
|
||||||
|
return TierGoogleOneStandard
|
||||||
|
}
|
||||||
|
if storageBytes >= StorageTierBasic {
|
||||||
|
return TierGoogleOneBasic
|
||||||
|
}
|
||||||
|
if storageBytes >= StorageTierFree {
|
||||||
|
return TierFree
|
||||||
|
}
|
||||||
|
return TierGoogleOneUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchGoogleOneTier fetches Google One tier from Drive API
|
||||||
|
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
|
||||||
|
driveClient := geminicli.NewDriveClient()
|
||||||
|
|
||||||
|
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
// Check if it's a 403 (scope not granted)
|
||||||
|
if strings.Contains(err.Error(), "status 403") {
|
||||||
|
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
|
||||||
|
return TierGoogleOneUnknown, nil, err
|
||||||
|
}
|
||||||
|
// Other errors
|
||||||
|
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
|
||||||
|
return TierGoogleOneUnknown, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tierID := inferGoogleOneTier(storageInfo.Limit)
|
||||||
|
return tierID, storageInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccountGoogleOneTier 刷新单个账号的 Google One Tier
|
||||||
|
func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
|
||||||
|
ctx context.Context,
|
||||||
|
account *Account,
|
||||||
|
) (tierID string, extra map[string]any, credentials map[string]any, err error) {
|
||||||
|
if account == nil {
|
||||||
|
return "", nil, nil, fmt.Errorf("account is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证账号类型
|
||||||
|
oauthType, ok := account.Credentials["oauth_type"].(string)
|
||||||
|
if !ok || oauthType != "google_one" {
|
||||||
|
return "", nil, nil, fmt.Errorf("not a google_one OAuth account")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 access_token
|
||||||
|
accessToken, ok := account.Credentials["access_token"].(string)
|
||||||
|
if !ok || accessToken == "" {
|
||||||
|
return "", nil, nil, fmt.Errorf("missing access_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 proxy URL
|
||||||
|
var proxyURL string
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用 Drive API
|
||||||
|
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, accessToken, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 extra 数据(保留原有 extra 字段)
|
||||||
|
extra = make(map[string]any)
|
||||||
|
for k, v := range account.Extra {
|
||||||
|
extra[k] = v
|
||||||
|
}
|
||||||
|
if storageInfo != nil {
|
||||||
|
extra["drive_storage_limit"] = storageInfo.Limit
|
||||||
|
extra["drive_storage_usage"] = storageInfo.Usage
|
||||||
|
extra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 credentials 数据
|
||||||
|
credentials = make(map[string]any)
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
credentials[k] = v
|
||||||
|
}
|
||||||
|
credentials["tier_id"] = tierID
|
||||||
|
|
||||||
|
return tierID, extra, credentials, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||||
@@ -219,26 +379,78 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
|||||||
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
||||||
s.sessionStore.Delete(input.SessionID)
|
s.sessionStore.Delete(input.SessionID)
|
||||||
|
|
||||||
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
|
// 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差)
|
||||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
|
||||||
|
const safetyWindow = 300 // 5 minutes
|
||||||
|
const minTTL = 30 // minimum 30 seconds
|
||||||
|
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
|
||||||
|
minExpiresAt := time.Now().Unix() + minTTL
|
||||||
|
if expiresAt < minExpiresAt {
|
||||||
|
expiresAt = minExpiresAt
|
||||||
|
}
|
||||||
|
|
||||||
projectID := sessionProjectID
|
projectID := sessionProjectID
|
||||||
|
var tierID string
|
||||||
|
|
||||||
// 对于 code_assist 模式,project_id 是必需的
|
// 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API
|
||||||
|
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别
|
||||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||||
if oauthType == "code_assist" {
|
switch oauthType {
|
||||||
|
case "code_assist":
|
||||||
if projectID == "" {
|
if projectID == "" {
|
||||||
var err error
|
var err error
|
||||||
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID
|
||||||
|
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
|
||||||
|
} else {
|
||||||
|
tierID = fetchedTierID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(projectID) == "" {
|
if strings.TrimSpace(projectID) == "" {
|
||||||
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
||||||
}
|
}
|
||||||
|
// tierID 缺失时使用默认值
|
||||||
|
if tierID == "" {
|
||||||
|
tierID = "LEGACY"
|
||||||
|
}
|
||||||
|
case "google_one":
|
||||||
|
// Attempt to fetch Drive storage tier
|
||||||
|
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
// Log warning but don't block - use fallback
|
||||||
|
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
|
||||||
|
tierID = TierGoogleOneUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store Drive info in extra field for caching
|
||||||
|
if storageInfo != nil {
|
||||||
|
tokenInfo := &GeminiTokenInfo{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
TokenType: tokenResp.TokenType,
|
||||||
|
ExpiresIn: tokenResp.ExpiresIn,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Scope: tokenResp.Scope,
|
||||||
|
ProjectID: projectID,
|
||||||
|
TierID: tierID,
|
||||||
|
OAuthType: oauthType,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"drive_storage_limit": storageInfo.Limit,
|
||||||
|
"drive_storage_usage": storageInfo.Usage,
|
||||||
|
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return tokenInfo, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
// ai_studio 模式不设置 tierID,保持为空
|
||||||
|
|
||||||
return &GeminiTokenInfo{
|
return &GeminiTokenInfo{
|
||||||
AccessToken: tokenResp.AccessToken,
|
AccessToken: tokenResp.AccessToken,
|
||||||
@@ -248,6 +460,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
|||||||
ExpiresAt: expiresAt,
|
ExpiresAt: expiresAt,
|
||||||
Scope: tokenResp.Scope,
|
Scope: tokenResp.Scope,
|
||||||
ProjectID: projectID,
|
ProjectID: projectID,
|
||||||
|
TierID: tierID,
|
||||||
OAuthType: oauthType,
|
OAuthType: oauthType,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -266,8 +479,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres
|
|||||||
|
|
||||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
|
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
|
// 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差)
|
||||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
|
||||||
|
const safetyWindow = 300 // 5 minutes
|
||||||
|
const minTTL = 30 // minimum 30 seconds
|
||||||
|
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
|
||||||
|
minExpiresAt := time.Now().Unix() + minTTL
|
||||||
|
if expiresAt < minExpiresAt {
|
||||||
|
expiresAt = minExpiresAt
|
||||||
|
}
|
||||||
return &GeminiTokenInfo{
|
return &GeminiTokenInfo{
|
||||||
AccessToken: tokenResp.AccessToken,
|
AccessToken: tokenResp.AccessToken,
|
||||||
RefreshToken: tokenResp.RefreshToken,
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
@@ -354,18 +574,75 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
|||||||
tokenInfo.ProjectID = existingProjectID
|
tokenInfo.ProjectID = existingProjectID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 尝试从账号凭证获取 tierID(向后兼容)
|
||||||
|
existingTierID := strings.TrimSpace(account.GetCredential("tier_id"))
|
||||||
|
|
||||||
// For Code Assist, project_id is required. Auto-detect if missing.
|
// For Code Assist, project_id is required. Auto-detect if missing.
|
||||||
// For AI Studio OAuth, project_id is optional and should not block refresh.
|
// For AI Studio OAuth, project_id is optional and should not block refresh.
|
||||||
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
|
switch oauthType {
|
||||||
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
case "code_assist":
|
||||||
if err != nil {
|
// 先设置默认值或保留旧值,确保 tier_id 始终有值
|
||||||
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
|
if existingTierID != "" {
|
||||||
|
tokenInfo.TierID = existingTierID
|
||||||
|
} else {
|
||||||
|
tokenInfo.TierID = "LEGACY" // 默认值
|
||||||
}
|
}
|
||||||
projectID = strings.TrimSpace(projectID)
|
|
||||||
if projectID == "" {
|
// 尝试自动探测 project_id 和 tier_id
|
||||||
|
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
|
||||||
|
if needDetect {
|
||||||
|
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err)
|
||||||
|
} else {
|
||||||
|
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
|
||||||
|
tokenInfo.ProjectID = projectID
|
||||||
|
}
|
||||||
|
// 只有当原来没有 tier_id 且探测成功时才更新
|
||||||
|
if existingTierID == "" && tierID != "" {
|
||||||
|
tokenInfo.TierID = tierID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(tokenInfo.ProjectID) == "" {
|
||||||
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
||||||
}
|
}
|
||||||
tokenInfo.ProjectID = projectID
|
case "google_one":
|
||||||
|
// Check if tier cache is stale (> 24 hours)
|
||||||
|
needsRefresh := true
|
||||||
|
if account.Extra != nil {
|
||||||
|
if updatedAtStr, ok := account.Extra["drive_tier_updated_at"].(string); ok {
|
||||||
|
if updatedAt, err := time.Parse(time.RFC3339, updatedAtStr); err == nil {
|
||||||
|
if time.Since(updatedAt) <= 24*time.Hour {
|
||||||
|
needsRefresh = false
|
||||||
|
// Use cached tier
|
||||||
|
if existingTierID != "" {
|
||||||
|
tokenInfo.TierID = existingTierID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsRefresh {
|
||||||
|
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
|
||||||
|
if err == nil && storageInfo != nil {
|
||||||
|
tokenInfo.TierID = tierID
|
||||||
|
tokenInfo.Extra = map[string]any{
|
||||||
|
"drive_storage_limit": storageInfo.Limit,
|
||||||
|
"drive_storage_usage": storageInfo.Usage,
|
||||||
|
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Fallback to cached or unknown
|
||||||
|
if existingTierID != "" {
|
||||||
|
tokenInfo.TierID = existingTierID
|
||||||
|
} else {
|
||||||
|
tokenInfo.TierID = TierGoogleOneUnknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokenInfo, nil
|
return tokenInfo, nil
|
||||||
@@ -388,9 +665,22 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
|||||||
if tokenInfo.ProjectID != "" {
|
if tokenInfo.ProjectID != "" {
|
||||||
creds["project_id"] = tokenInfo.ProjectID
|
creds["project_id"] = tokenInfo.ProjectID
|
||||||
}
|
}
|
||||||
|
if tokenInfo.TierID != "" {
|
||||||
|
// Validate tier_id before storing
|
||||||
|
if err := validateTierID(tokenInfo.TierID); err == nil {
|
||||||
|
creds["tier_id"] = tokenInfo.TierID
|
||||||
|
}
|
||||||
|
// Silently skip invalid tier_id (don't block account creation)
|
||||||
|
}
|
||||||
if tokenInfo.OAuthType != "" {
|
if tokenInfo.OAuthType != "" {
|
||||||
creds["oauth_type"] = tokenInfo.OAuthType
|
creds["oauth_type"] = tokenInfo.OAuthType
|
||||||
}
|
}
|
||||||
|
// Store extra metadata (Drive info) if present
|
||||||
|
if len(tokenInfo.Extra) > 0 {
|
||||||
|
for k, v := range tokenInfo.Extra {
|
||||||
|
creds[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
return creds
|
return creds
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -398,33 +688,22 @@ func (s *GeminiOAuthService) Stop() {
|
|||||||
s.sessionStore.Stop()
|
s.sessionStore.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
|
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
|
||||||
if s.codeAssist == nil {
|
if s.codeAssist == nil {
|
||||||
return "", errors.New("code assist client not configured")
|
return "", "", errors.New("code assist client not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
|
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
|
||||||
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
|
|
||||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
|
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
|
||||||
tierID := "LEGACY"
|
tierID := "LEGACY"
|
||||||
if loadResp != nil {
|
if loadResp != nil {
|
||||||
for _, tier := range loadResp.AllowedTiers {
|
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||||
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
|
}
|
||||||
tierID = strings.TrimSpace(tier.ID)
|
|
||||||
break
|
// If LoadCodeAssist returned a project, use it
|
||||||
}
|
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
|
||||||
}
|
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
|
||||||
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
|
|
||||||
for _, tier := range loadResp.AllowedTiers {
|
|
||||||
if strings.TrimSpace(tier.ID) != "" {
|
|
||||||
tierID = strings.TrimSpace(tier.ID)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &geminicli.OnboardUserRequest{
|
req := &geminicli.OnboardUserRequest{
|
||||||
@@ -443,39 +722,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
|||||||
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
|
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
|
||||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||||
return strings.TrimSpace(fallback), nil
|
return strings.TrimSpace(fallback), tierID, nil
|
||||||
}
|
}
|
||||||
return "", err
|
return "", tierID, err
|
||||||
}
|
}
|
||||||
if resp.Done {
|
if resp.Done {
|
||||||
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
|
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
|
||||||
switch v := resp.Response.CloudAICompanionProject.(type) {
|
switch v := resp.Response.CloudAICompanionProject.(type) {
|
||||||
case string:
|
case string:
|
||||||
return strings.TrimSpace(v), nil
|
return strings.TrimSpace(v), tierID, nil
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
if id, ok := v["id"].(string); ok {
|
if id, ok := v["id"].(string); ok {
|
||||||
return strings.TrimSpace(id), nil
|
return strings.TrimSpace(id), tierID, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||||
return strings.TrimSpace(fallback), nil
|
return strings.TrimSpace(fallback), tierID, nil
|
||||||
}
|
}
|
||||||
return "", errors.New("onboardUser completed but no project_id returned")
|
return "", tierID, errors.New("onboardUser completed but no project_id returned")
|
||||||
}
|
}
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||||
return strings.TrimSpace(fallback), nil
|
return strings.TrimSpace(fallback), tierID, nil
|
||||||
}
|
}
|
||||||
if loadErr != nil {
|
if loadErr != nil {
|
||||||
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
||||||
}
|
}
|
||||||
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
||||||
}
|
}
|
||||||
|
|
||||||
type googleCloudProject struct {
|
type googleCloudProject struct {
|
||||||
|
|||||||
51
backend/internal/service/gemini_oauth_service_test.go
Normal file
51
backend/internal/service/gemini_oauth_service_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestInferGoogleOneTier(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
storageBytes int64
|
||||||
|
expectedTier string
|
||||||
|
}{
|
||||||
|
{"Negative storage", -1, TierGoogleOneUnknown},
|
||||||
|
{"Zero storage", 0, TierGoogleOneUnknown},
|
||||||
|
|
||||||
|
// Free tier boundary (15GB)
|
||||||
|
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
|
||||||
|
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
|
||||||
|
{"Free tier (15GB)", StorageTierFree, TierFree},
|
||||||
|
|
||||||
|
// Basic tier boundary (100GB)
|
||||||
|
{"Between free and basic", 50 * GB, TierFree},
|
||||||
|
{"Just below basic tier", StorageTierBasic - 1, TierFree},
|
||||||
|
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
|
||||||
|
|
||||||
|
// Standard tier boundary (200GB)
|
||||||
|
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
|
||||||
|
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
|
||||||
|
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
|
||||||
|
|
||||||
|
// AI Premium tier boundary (2TB)
|
||||||
|
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
|
||||||
|
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
|
||||||
|
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
|
||||||
|
|
||||||
|
// Unlimited tier boundary (> 100TB)
|
||||||
|
{"Between premium and unlimited", 50 * TB, TierAIPremium},
|
||||||
|
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
|
||||||
|
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
|
||||||
|
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
|
||||||
|
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := inferGoogleOneTier(tt.storageBytes)
|
||||||
|
if result != tt.expectedTier {
|
||||||
|
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
|
||||||
|
tt.storageBytes, result, tt.expectedTier)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
268
backend/internal/service/gemini_quota.go
Normal file
268
backend/internal/service/gemini_quota.go
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type geminiModelClass string
|
||||||
|
|
||||||
|
const (
|
||||||
|
geminiModelPro geminiModelClass = "pro"
|
||||||
|
geminiModelFlash geminiModelClass = "flash"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GeminiDailyQuota struct {
|
||||||
|
ProRPD int64
|
||||||
|
FlashRPD int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiTierPolicy struct {
|
||||||
|
Quota GeminiDailyQuota
|
||||||
|
Cooldown time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiQuotaPolicy struct {
|
||||||
|
tiers map[string]GeminiTierPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiUsageTotals struct {
|
||||||
|
ProRequests int64
|
||||||
|
FlashRequests int64
|
||||||
|
ProTokens int64
|
||||||
|
FlashTokens int64
|
||||||
|
ProCost float64
|
||||||
|
FlashCost float64
|
||||||
|
}
|
||||||
|
|
||||||
|
const geminiQuotaCacheTTL = time.Minute
|
||||||
|
|
||||||
|
type geminiQuotaOverrides struct {
|
||||||
|
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiQuotaService struct {
|
||||||
|
cfg *config.Config
|
||||||
|
settingRepo SettingRepository
|
||||||
|
mu sync.Mutex
|
||||||
|
cachedAt time.Time
|
||||||
|
policy *GeminiQuotaPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService {
|
||||||
|
return &GeminiQuotaService{
|
||||||
|
cfg: cfg,
|
||||||
|
settingRepo: settingRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
||||||
|
if s == nil {
|
||||||
|
return newGeminiQuotaPolicy()
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.policy != nil && now.Sub(s.cachedAt) < geminiQuotaCacheTTL {
|
||||||
|
policy := s.policy
|
||||||
|
s.mu.Unlock()
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
policy := newGeminiQuotaPolicy()
|
||||||
|
if s.cfg != nil {
|
||||||
|
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
|
||||||
|
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
|
||||||
|
var overrides geminiQuotaOverrides
|
||||||
|
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
|
||||||
|
log.Printf("gemini quota: parse config policy failed: %v", err)
|
||||||
|
} else {
|
||||||
|
policy.ApplyOverrides(overrides.Tiers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.settingRepo != nil {
|
||||||
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy)
|
||||||
|
if err != nil && !errors.Is(err, ErrSettingNotFound) {
|
||||||
|
log.Printf("gemini quota: load setting failed: %v", err)
|
||||||
|
} else if strings.TrimSpace(value) != "" {
|
||||||
|
var overrides geminiQuotaOverrides
|
||||||
|
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
|
||||||
|
log.Printf("gemini quota: parse setting failed: %v", err)
|
||||||
|
} else {
|
||||||
|
policy.ApplyOverrides(overrides.Tiers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.policy = policy
|
||||||
|
s.cachedAt = now
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
|
||||||
|
if account == nil || !account.IsGeminiCodeAssist() {
|
||||||
|
return GeminiDailyQuota{}, false
|
||||||
|
}
|
||||||
|
policy := s.Policy(ctx)
|
||||||
|
return policy.QuotaForTier(account.GeminiTierID())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
|
||||||
|
policy := s.Policy(ctx)
|
||||||
|
return policy.CooldownForTier(tierID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
|
||||||
|
return &GeminiQuotaPolicy{
|
||||||
|
tiers: map[string]GeminiTierPolicy{
|
||||||
|
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
|
||||||
|
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
|
||||||
|
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) {
|
||||||
|
if p == nil || len(tiers) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for rawID, override := range tiers {
|
||||||
|
tierID := normalizeGeminiTierID(rawID)
|
||||||
|
if tierID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
policy, ok := p.tiers[tierID]
|
||||||
|
if !ok {
|
||||||
|
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
||||||
|
}
|
||||||
|
if override.ProRPD != nil {
|
||||||
|
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
|
||||||
|
}
|
||||||
|
if override.FlashRPD != nil {
|
||||||
|
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
|
||||||
|
}
|
||||||
|
if override.CooldownMinutes != nil {
|
||||||
|
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
|
||||||
|
policy.Cooldown = time.Duration(minutes) * time.Minute
|
||||||
|
}
|
||||||
|
p.tiers[tierID] = policy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
|
||||||
|
policy, ok := p.policyForTier(tierID)
|
||||||
|
if !ok {
|
||||||
|
return GeminiDailyQuota{}, false
|
||||||
|
}
|
||||||
|
return policy.Quota, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration {
|
||||||
|
policy, ok := p.policyForTier(tierID)
|
||||||
|
if ok && policy.Cooldown > 0 {
|
||||||
|
return policy.Cooldown
|
||||||
|
}
|
||||||
|
return 5 * time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool) {
|
||||||
|
if p == nil {
|
||||||
|
return GeminiTierPolicy{}, false
|
||||||
|
}
|
||||||
|
normalized := normalizeGeminiTierID(tierID)
|
||||||
|
if normalized == "" {
|
||||||
|
normalized = "LEGACY"
|
||||||
|
}
|
||||||
|
if policy, ok := p.tiers[normalized]; ok {
|
||||||
|
return policy, true
|
||||||
|
}
|
||||||
|
policy, ok := p.tiers["LEGACY"]
|
||||||
|
return policy, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeGeminiTierID(tierID string) string {
|
||||||
|
return strings.ToUpper(strings.TrimSpace(tierID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func clampGeminiQuotaInt64(value int64) int64 {
|
||||||
|
if value < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func clampGeminiQuotaInt(value int) int {
|
||||||
|
if value < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiCooldownForTier(tierID string) time.Duration {
|
||||||
|
policy := newGeminiQuotaPolicy()
|
||||||
|
return policy.CooldownForTier(tierID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiModelClassFromName(model string) geminiModelClass {
|
||||||
|
name := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
|
||||||
|
return geminiModelFlash
|
||||||
|
}
|
||||||
|
return geminiModelPro
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals {
|
||||||
|
var totals GeminiUsageTotals
|
||||||
|
for _, stat := range stats {
|
||||||
|
switch geminiModelClassFromName(stat.Model) {
|
||||||
|
case geminiModelFlash:
|
||||||
|
totals.FlashRequests += stat.Requests
|
||||||
|
totals.FlashTokens += stat.TotalTokens
|
||||||
|
totals.FlashCost += stat.ActualCost
|
||||||
|
default:
|
||||||
|
totals.ProRequests += stat.Requests
|
||||||
|
totals.ProTokens += stat.TotalTokens
|
||||||
|
totals.ProCost += stat.ActualCost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return totals
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiQuotaLocation() *time.Location {
|
||||||
|
loc, err := time.LoadLocation("America/Los_Angeles")
|
||||||
|
if err != nil {
|
||||||
|
return time.FixedZone("PST", -8*3600)
|
||||||
|
}
|
||||||
|
return loc
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiDailyWindowStart(now time.Time) time.Time {
|
||||||
|
loc := geminiQuotaLocation()
|
||||||
|
localNow := now.In(loc)
|
||||||
|
return time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiDailyResetTime(now time.Time) time.Time {
|
||||||
|
loc := geminiQuotaLocation()
|
||||||
|
localNow := now.In(loc)
|
||||||
|
start := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
|
||||||
|
reset := start.Add(24 * time.Hour)
|
||||||
|
if !reset.After(localNow) {
|
||||||
|
reset = reset.Add(24 * time.Hour)
|
||||||
|
}
|
||||||
|
return reset
|
||||||
|
}
|
||||||
@@ -112,17 +112,21 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
|
detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
|
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
|
||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
detected = strings.TrimSpace(detected)
|
detected = strings.TrimSpace(detected)
|
||||||
|
tierID = strings.TrimSpace(tierID)
|
||||||
if detected != "" {
|
if detected != "" {
|
||||||
if account.Credentials == nil {
|
if account.Credentials == nil {
|
||||||
account.Credentials = make(map[string]any)
|
account.Credentials = make(map[string]any)
|
||||||
}
|
}
|
||||||
account.Credentials["project_id"] = detected
|
account.Credentials["project_id"] = detected
|
||||||
|
if tierID != "" {
|
||||||
|
account.Credentials["tier_id"] = tierID
|
||||||
|
}
|
||||||
_ = p.accountRepo.Update(ctx, account)
|
_ = p.accountRepo.Update(ctx, account)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -80,6 +81,7 @@ type OpenAIGatewayService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
concurrencyService *ConcurrencyService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
@@ -95,6 +97,7 @@ func NewOpenAIGatewayService(
|
|||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
cache GatewayCache,
|
cache GatewayCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
|
concurrencyService *ConcurrencyService,
|
||||||
billingService *BillingService,
|
billingService *BillingService,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
@@ -108,6 +111,7 @@ func NewOpenAIGatewayService(
|
|||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
concurrencyService: concurrencyService,
|
||||||
billingService: billingService,
|
billingService: billingService,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
@@ -126,6 +130,14 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
|||||||
return hex.EncodeToString(hash[:])
|
return hex.EncodeToString(hash[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BindStickySession sets session -> account binding with standard TTL.
|
||||||
|
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||||
|
if sessionHash == "" || accountID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL)
|
||||||
|
}
|
||||||
|
|
||||||
// SelectAccount selects an OpenAI account with sticky session support
|
// SelectAccount selects an OpenAI account with sticky session support
|
||||||
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
||||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||||
@@ -218,6 +230,254 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
|||||||
return selected, nil
|
return selected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||||
|
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||||
|
cfg := s.schedulingConfig()
|
||||||
|
var stickyAccountID int64
|
||||||
|
if sessionHash != "" && s.cache != nil {
|
||||||
|
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil {
|
||||||
|
stickyAccountID = accountID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||||
|
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||||
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||||
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: account.ID,
|
||||||
|
MaxConcurrency: account.Concurrency,
|
||||||
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: account.ID,
|
||||||
|
MaxConcurrency: account.Concurrency,
|
||||||
|
Timeout: cfg.FallbackWaitTimeout,
|
||||||
|
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(accounts) == 0 {
|
||||||
|
return nil, errors.New("no available accounts")
|
||||||
|
}
|
||||||
|
|
||||||
|
isExcluded := func(accountID int64) bool {
|
||||||
|
if excludedIDs == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, excluded := excludedIDs[accountID]
|
||||||
|
return excluded
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Layer 1: Sticky session ============
|
||||||
|
if sessionHash != "" {
|
||||||
|
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||||
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||||
|
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||||
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: accountID,
|
||||||
|
MaxConcurrency: account.Concurrency,
|
||||||
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Layer 2: Load-aware selection ============
|
||||||
|
candidates := make([]*Account, 0, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
acc := &accounts[i]
|
||||||
|
if isExcluded(acc.ID) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidates = append(candidates, acc)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil, errors.New("no available accounts")
|
||||||
|
}
|
||||||
|
|
||||||
|
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||||
|
for _, acc := range candidates {
|
||||||
|
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||||
|
ID: acc.ID,
|
||||||
|
MaxConcurrency: acc.Concurrency,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||||
|
if err != nil {
|
||||||
|
ordered := append([]*Account(nil), candidates...)
|
||||||
|
sortAccountsByPriorityAndLastUsed(ordered, false)
|
||||||
|
for _, acc := range ordered {
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
if sessionHash != "" {
|
||||||
|
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: acc,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
type accountWithLoad struct {
|
||||||
|
account *Account
|
||||||
|
loadInfo *AccountLoadInfo
|
||||||
|
}
|
||||||
|
var available []accountWithLoad
|
||||||
|
for _, acc := range candidates {
|
||||||
|
loadInfo := loadMap[acc.ID]
|
||||||
|
if loadInfo == nil {
|
||||||
|
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||||
|
}
|
||||||
|
if loadInfo.LoadRate < 100 {
|
||||||
|
available = append(available, accountWithLoad{
|
||||||
|
account: acc,
|
||||||
|
loadInfo: loadInfo,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(available) > 0 {
|
||||||
|
sort.SliceStable(available, func(i, j int) bool {
|
||||||
|
a, b := available[i], available[j]
|
||||||
|
if a.account.Priority != b.account.Priority {
|
||||||
|
return a.account.Priority < b.account.Priority
|
||||||
|
}
|
||||||
|
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||||
|
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||||
|
return true
|
||||||
|
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||||
|
return false
|
||||||
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, item := range available {
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
if sessionHash != "" {
|
||||||
|
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: item.account,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Layer 3: Fallback wait ============
|
||||||
|
sortAccountsByPriorityAndLastUsed(candidates, false)
|
||||||
|
for _, acc := range candidates {
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: acc,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: acc.ID,
|
||||||
|
MaxConcurrency: acc.Concurrency,
|
||||||
|
Timeout: cfg.FallbackWaitTimeout,
|
||||||
|
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("no available accounts")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
||||||
|
var accounts []Account
|
||||||
|
var err error
|
||||||
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||||
|
} else if groupID != nil {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
||||||
|
} else {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
|
}
|
||||||
|
return accounts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||||
|
if s.concurrencyService == nil {
|
||||||
|
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
||||||
|
}
|
||||||
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||||
|
if s.cfg != nil {
|
||||||
|
return s.cfg.Gateway.Scheduling
|
||||||
|
}
|
||||||
|
return config.GatewaySchedulingConfig{
|
||||||
|
StickySessionMaxWaiting: 3,
|
||||||
|
StickySessionWaitTimeout: 45 * time.Second,
|
||||||
|
FallbackWaitTimeout: 30 * time.Second,
|
||||||
|
FallbackMaxWaiting: 100,
|
||||||
|
LoadBatchEnabled: true,
|
||||||
|
SlotCleanupInterval: 30 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccessToken gets the access token for an OpenAI account
|
// GetAccessToken gets the access token for an OpenAI account
|
||||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||||
switch account.Type {
|
switch account.Type {
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
@@ -12,15 +14,30 @@ import (
|
|||||||
|
|
||||||
// RateLimitService 处理限流和过载状态管理
|
// RateLimitService 处理限流和过载状态管理
|
||||||
type RateLimitService struct {
|
type RateLimitService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
cfg *config.Config
|
usageRepo UsageLogRepository
|
||||||
|
cfg *config.Config
|
||||||
|
geminiQuotaService *GeminiQuotaService
|
||||||
|
usageCacheMu sync.RWMutex
|
||||||
|
usageCache map[int64]*geminiUsageCacheEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type geminiUsageCacheEntry struct {
|
||||||
|
windowStart time.Time
|
||||||
|
cachedAt time.Time
|
||||||
|
totals GeminiUsageTotals
|
||||||
|
}
|
||||||
|
|
||||||
|
const geminiPrecheckCacheTTL = time.Minute
|
||||||
|
|
||||||
// NewRateLimitService 创建RateLimitService实例
|
// NewRateLimitService 创建RateLimitService实例
|
||||||
func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *RateLimitService {
|
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
|
||||||
return &RateLimitService{
|
return &RateLimitService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
cfg: cfg,
|
usageRepo: usageRepo,
|
||||||
|
cfg: cfg,
|
||||||
|
geminiQuotaService: geminiQuotaService,
|
||||||
|
usageCache: make(map[int64]*geminiUsageCacheEntry),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +79,106 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PreCheckUsage proactively checks local quota before dispatching a request.
|
||||||
|
// Returns false when the account should be skipped.
|
||||||
|
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
|
||||||
|
if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if s.usageRepo == nil || s.geminiQuotaService == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
|
||||||
|
if !ok {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var limit int64
|
||||||
|
switch geminiModelClassFromName(requestedModel) {
|
||||||
|
case geminiModelFlash:
|
||||||
|
limit = quota.FlashRPD
|
||||||
|
default:
|
||||||
|
limit = quota.ProRPD
|
||||||
|
}
|
||||||
|
if limit <= 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
start := geminiDailyWindowStart(now)
|
||||||
|
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
|
||||||
|
if !ok {
|
||||||
|
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
totals = geminiAggregateUsage(stats)
|
||||||
|
s.setGeminiUsageTotals(account.ID, start, now, totals)
|
||||||
|
}
|
||||||
|
|
||||||
|
var used int64
|
||||||
|
switch geminiModelClassFromName(requestedModel) {
|
||||||
|
case geminiModelFlash:
|
||||||
|
used = totals.FlashRequests
|
||||||
|
default:
|
||||||
|
used = totals.ProRequests
|
||||||
|
}
|
||||||
|
|
||||||
|
if used >= limit {
|
||||||
|
resetAt := geminiDailyResetTime(now)
|
||||||
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||||
|
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||||
|
}
|
||||||
|
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
|
||||||
|
s.usageCacheMu.RLock()
|
||||||
|
defer s.usageCacheMu.RUnlock()
|
||||||
|
|
||||||
|
if s.usageCache == nil {
|
||||||
|
return GeminiUsageTotals{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, ok := s.usageCache[accountID]
|
||||||
|
if !ok || entry == nil {
|
||||||
|
return GeminiUsageTotals{}, false
|
||||||
|
}
|
||||||
|
if !entry.windowStart.Equal(windowStart) {
|
||||||
|
return GeminiUsageTotals{}, false
|
||||||
|
}
|
||||||
|
if now.Sub(entry.cachedAt) >= geminiPrecheckCacheTTL {
|
||||||
|
return GeminiUsageTotals{}, false
|
||||||
|
}
|
||||||
|
return entry.totals, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RateLimitService) setGeminiUsageTotals(accountID int64, windowStart, now time.Time, totals GeminiUsageTotals) {
|
||||||
|
s.usageCacheMu.Lock()
|
||||||
|
defer s.usageCacheMu.Unlock()
|
||||||
|
if s.usageCache == nil {
|
||||||
|
s.usageCache = make(map[int64]*geminiUsageCacheEntry)
|
||||||
|
}
|
||||||
|
s.usageCache[accountID] = &geminiUsageCacheEntry{
|
||||||
|
windowStart: windowStart,
|
||||||
|
cachedAt: now,
|
||||||
|
totals: totals,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiCooldown returns the fallback cooldown duration for Gemini 429s based on tier.
|
||||||
|
func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) time.Duration {
|
||||||
|
if account == nil {
|
||||||
|
return 5 * time.Minute
|
||||||
|
}
|
||||||
|
return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID())
|
||||||
|
}
|
||||||
|
|
||||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ type User struct {
|
|||||||
ID int64
|
ID int64
|
||||||
Email string
|
Email string
|
||||||
Username string
|
Username string
|
||||||
Wechat string
|
|
||||||
Notes string
|
Notes string
|
||||||
PasswordHash string
|
PasswordHash string
|
||||||
Role string
|
Role string
|
||||||
|
|||||||
125
backend/internal/service/user_attribute.go
Normal file
125
backend/internal/service/user_attribute.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error definitions for user attribute operations
|
||||||
|
var (
|
||||||
|
ErrAttributeDefinitionNotFound = infraerrors.NotFound("ATTRIBUTE_DEFINITION_NOT_FOUND", "attribute definition not found")
|
||||||
|
ErrAttributeKeyExists = infraerrors.Conflict("ATTRIBUTE_KEY_EXISTS", "attribute key already exists")
|
||||||
|
ErrInvalidAttributeType = infraerrors.BadRequest("INVALID_ATTRIBUTE_TYPE", "invalid attribute type")
|
||||||
|
ErrAttributeValidationFailed = infraerrors.BadRequest("ATTRIBUTE_VALIDATION_FAILED", "attribute value validation failed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeType represents supported attribute types
|
||||||
|
type UserAttributeType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AttributeTypeText UserAttributeType = "text"
|
||||||
|
AttributeTypeTextarea UserAttributeType = "textarea"
|
||||||
|
AttributeTypeNumber UserAttributeType = "number"
|
||||||
|
AttributeTypeEmail UserAttributeType = "email"
|
||||||
|
AttributeTypeURL UserAttributeType = "url"
|
||||||
|
AttributeTypeDate UserAttributeType = "date"
|
||||||
|
AttributeTypeSelect UserAttributeType = "select"
|
||||||
|
AttributeTypeMultiSelect UserAttributeType = "multi_select"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeOption represents a select option for select/multi_select types
|
||||||
|
type UserAttributeOption struct {
|
||||||
|
Value string `json:"value"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValidation represents validation rules for an attribute
|
||||||
|
type UserAttributeValidation struct {
|
||||||
|
MinLength *int `json:"min_length,omitempty"`
|
||||||
|
MaxLength *int `json:"max_length,omitempty"`
|
||||||
|
Min *int `json:"min,omitempty"`
|
||||||
|
Max *int `json:"max,omitempty"`
|
||||||
|
Pattern *string `json:"pattern,omitempty"`
|
||||||
|
Message *string `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeDefinition represents a custom attribute definition
|
||||||
|
type UserAttributeDefinition struct {
|
||||||
|
ID int64
|
||||||
|
Key string
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Type UserAttributeType
|
||||||
|
Options []UserAttributeOption
|
||||||
|
Required bool
|
||||||
|
Validation UserAttributeValidation
|
||||||
|
Placeholder string
|
||||||
|
DisplayOrder int
|
||||||
|
Enabled bool
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValue represents a user's attribute value
|
||||||
|
type UserAttributeValue struct {
|
||||||
|
ID int64
|
||||||
|
UserID int64
|
||||||
|
AttributeID int64
|
||||||
|
Value string
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAttributeDefinitionInput for creating new definition
|
||||||
|
type CreateAttributeDefinitionInput struct {
|
||||||
|
Key string
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Type UserAttributeType
|
||||||
|
Options []UserAttributeOption
|
||||||
|
Required bool
|
||||||
|
Validation UserAttributeValidation
|
||||||
|
Placeholder string
|
||||||
|
Enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAttributeDefinitionInput for updating definition
|
||||||
|
type UpdateAttributeDefinitionInput struct {
|
||||||
|
Name *string
|
||||||
|
Description *string
|
||||||
|
Type *UserAttributeType
|
||||||
|
Options *[]UserAttributeOption
|
||||||
|
Required *bool
|
||||||
|
Validation *UserAttributeValidation
|
||||||
|
Placeholder *string
|
||||||
|
Enabled *bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAttributeInput for updating a single attribute value
|
||||||
|
type UpdateUserAttributeInput struct {
|
||||||
|
AttributeID int64
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeDefinitionRepository interface for attribute definition persistence
|
||||||
|
type UserAttributeDefinitionRepository interface {
|
||||||
|
Create(ctx context.Context, def *UserAttributeDefinition) error
|
||||||
|
GetByID(ctx context.Context, id int64) (*UserAttributeDefinition, error)
|
||||||
|
GetByKey(ctx context.Context, key string) (*UserAttributeDefinition, error)
|
||||||
|
Update(ctx context.Context, def *UserAttributeDefinition) error
|
||||||
|
Delete(ctx context.Context, id int64) error
|
||||||
|
List(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error)
|
||||||
|
UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error
|
||||||
|
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAttributeValueRepository interface for user attribute value persistence
|
||||||
|
type UserAttributeValueRepository interface {
|
||||||
|
GetByUserID(ctx context.Context, userID int64) ([]UserAttributeValue, error)
|
||||||
|
GetByUserIDs(ctx context.Context, userIDs []int64) ([]UserAttributeValue, error)
|
||||||
|
UpsertBatch(ctx context.Context, userID int64, values []UpdateUserAttributeInput) error
|
||||||
|
DeleteByAttributeID(ctx context.Context, attributeID int64) error
|
||||||
|
DeleteByUserID(ctx context.Context, userID int64) error
|
||||||
|
}
|
||||||
295
backend/internal/service/user_attribute_service.go
Normal file
295
backend/internal/service/user_attribute_service.go
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAttributeService handles attribute management
|
||||||
|
type UserAttributeService struct {
|
||||||
|
defRepo UserAttributeDefinitionRepository
|
||||||
|
valueRepo UserAttributeValueRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserAttributeService creates a new service instance
|
||||||
|
func NewUserAttributeService(
|
||||||
|
defRepo UserAttributeDefinitionRepository,
|
||||||
|
valueRepo UserAttributeValueRepository,
|
||||||
|
) *UserAttributeService {
|
||||||
|
return &UserAttributeService{
|
||||||
|
defRepo: defRepo,
|
||||||
|
valueRepo: valueRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateDefinition creates a new attribute definition
|
||||||
|
func (s *UserAttributeService) CreateDefinition(ctx context.Context, input CreateAttributeDefinitionInput) (*UserAttributeDefinition, error) {
|
||||||
|
// Validate type
|
||||||
|
if !isValidAttributeType(input.Type) {
|
||||||
|
return nil, ErrInvalidAttributeType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if key exists
|
||||||
|
exists, err := s.defRepo.ExistsByKey(ctx, input.Key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("check key exists: %w", err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
return nil, ErrAttributeKeyExists
|
||||||
|
}
|
||||||
|
|
||||||
|
def := &UserAttributeDefinition{
|
||||||
|
Key: input.Key,
|
||||||
|
Name: input.Name,
|
||||||
|
Description: input.Description,
|
||||||
|
Type: input.Type,
|
||||||
|
Options: input.Options,
|
||||||
|
Required: input.Required,
|
||||||
|
Validation: input.Validation,
|
||||||
|
Placeholder: input.Placeholder,
|
||||||
|
Enabled: input.Enabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.defRepo.Create(ctx, def); err != nil {
|
||||||
|
return nil, fmt.Errorf("create definition: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return def, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefinition retrieves a definition by ID
|
||||||
|
func (s *UserAttributeService) GetDefinition(ctx context.Context, id int64) (*UserAttributeDefinition, error) {
|
||||||
|
return s.defRepo.GetByID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListDefinitions lists all definitions
|
||||||
|
func (s *UserAttributeService) ListDefinitions(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error) {
|
||||||
|
return s.defRepo.List(ctx, enabledOnly)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDefinition updates an existing definition
|
||||||
|
func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, input UpdateAttributeDefinitionInput) (*UserAttributeDefinition, error) {
|
||||||
|
def, err := s.defRepo.GetByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Name != nil {
|
||||||
|
def.Name = *input.Name
|
||||||
|
}
|
||||||
|
if input.Description != nil {
|
||||||
|
def.Description = *input.Description
|
||||||
|
}
|
||||||
|
if input.Type != nil {
|
||||||
|
if !isValidAttributeType(*input.Type) {
|
||||||
|
return nil, ErrInvalidAttributeType
|
||||||
|
}
|
||||||
|
def.Type = *input.Type
|
||||||
|
}
|
||||||
|
if input.Options != nil {
|
||||||
|
def.Options = *input.Options
|
||||||
|
}
|
||||||
|
if input.Required != nil {
|
||||||
|
def.Required = *input.Required
|
||||||
|
}
|
||||||
|
if input.Validation != nil {
|
||||||
|
def.Validation = *input.Validation
|
||||||
|
}
|
||||||
|
if input.Placeholder != nil {
|
||||||
|
def.Placeholder = *input.Placeholder
|
||||||
|
}
|
||||||
|
if input.Enabled != nil {
|
||||||
|
def.Enabled = *input.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.defRepo.Update(ctx, def); err != nil {
|
||||||
|
return nil, fmt.Errorf("update definition: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return def, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDefinition soft-deletes a definition and hard-deletes associated values
|
||||||
|
func (s *UserAttributeService) DeleteDefinition(ctx context.Context, id int64) error {
|
||||||
|
// Check if definition exists
|
||||||
|
_, err := s.defRepo.GetByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// First delete all values (hard delete)
|
||||||
|
if err := s.valueRepo.DeleteByAttributeID(ctx, id); err != nil {
|
||||||
|
return fmt.Errorf("delete values: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then soft-delete the definition
|
||||||
|
if err := s.defRepo.Delete(ctx, id); err != nil {
|
||||||
|
return fmt.Errorf("delete definition: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReorderDefinitions updates display order for multiple definitions
|
||||||
|
func (s *UserAttributeService) ReorderDefinitions(ctx context.Context, orders map[int64]int) error {
|
||||||
|
return s.defRepo.UpdateDisplayOrders(ctx, orders)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserAttributes retrieves all attribute values for a user
|
||||||
|
func (s *UserAttributeService) GetUserAttributes(ctx context.Context, userID int64) ([]UserAttributeValue, error) {
|
||||||
|
return s.valueRepo.GetByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBatchUserAttributes retrieves attribute values for multiple users
|
||||||
|
// Returns a map of userID -> map of attributeID -> value
|
||||||
|
func (s *UserAttributeService) GetBatchUserAttributes(ctx context.Context, userIDs []int64) (map[int64]map[int64]string, error) {
|
||||||
|
values, err := s.valueRepo.GetByUserIDs(ctx, userIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[int64]map[int64]string)
|
||||||
|
for _, v := range values {
|
||||||
|
if result[v.UserID] == nil {
|
||||||
|
result[v.UserID] = make(map[int64]string)
|
||||||
|
}
|
||||||
|
result[v.UserID][v.AttributeID] = v.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAttributes batch updates attribute values for a user
|
||||||
|
func (s *UserAttributeService) UpdateUserAttributes(ctx context.Context, userID int64, inputs []UpdateUserAttributeInput) error {
|
||||||
|
// Validate all values before updating
|
||||||
|
defs, err := s.defRepo.List(ctx, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list definitions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defMap := make(map[int64]*UserAttributeDefinition, len(defs))
|
||||||
|
for i := range defs {
|
||||||
|
defMap[defs[i].ID] = &defs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, input := range inputs {
|
||||||
|
def, ok := defMap[input.AttributeID]
|
||||||
|
if !ok {
|
||||||
|
return ErrAttributeDefinitionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.validateValue(def, input.Value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.valueRepo.UpsertBatch(ctx, userID, inputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateValue validates a value against its definition
|
||||||
|
func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value string) error {
|
||||||
|
// Skip validation for empty non-required fields
|
||||||
|
if value == "" && !def.Required {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required check
|
||||||
|
if def.Required && value == "" {
|
||||||
|
return validationError(fmt.Sprintf("%s is required", def.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
v := def.Validation
|
||||||
|
|
||||||
|
// String length validation
|
||||||
|
if v.MinLength != nil && len(value) < *v.MinLength {
|
||||||
|
return validationError(fmt.Sprintf("%s must be at least %d characters", def.Name, *v.MinLength))
|
||||||
|
}
|
||||||
|
if v.MaxLength != nil && len(value) > *v.MaxLength {
|
||||||
|
return validationError(fmt.Sprintf("%s must be at most %d characters", def.Name, *v.MaxLength))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number validation
|
||||||
|
if def.Type == AttributeTypeNumber && value != "" {
|
||||||
|
num, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return validationError(fmt.Sprintf("%s must be a number", def.Name))
|
||||||
|
}
|
||||||
|
if v.Min != nil && num < *v.Min {
|
||||||
|
return validationError(fmt.Sprintf("%s must be at least %d", def.Name, *v.Min))
|
||||||
|
}
|
||||||
|
if v.Max != nil && num > *v.Max {
|
||||||
|
return validationError(fmt.Sprintf("%s must be at most %d", def.Name, *v.Max))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern validation
|
||||||
|
if v.Pattern != nil && *v.Pattern != "" && value != "" {
|
||||||
|
re, err := regexp.Compile(*v.Pattern)
|
||||||
|
if err == nil && !re.MatchString(value) {
|
||||||
|
msg := def.Name + " format is invalid"
|
||||||
|
if v.Message != nil && *v.Message != "" {
|
||||||
|
msg = *v.Message
|
||||||
|
}
|
||||||
|
return validationError(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select validation
|
||||||
|
if def.Type == AttributeTypeSelect && value != "" {
|
||||||
|
found := false
|
||||||
|
for _, opt := range def.Options {
|
||||||
|
if opt.Value == value {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return validationError(fmt.Sprintf("%s: invalid option", def.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-select validation (stored as JSON array)
|
||||||
|
if def.Type == AttributeTypeMultiSelect && value != "" {
|
||||||
|
var values []string
|
||||||
|
if err := json.Unmarshal([]byte(value), &values); err != nil {
|
||||||
|
// Try comma-separated fallback
|
||||||
|
values = strings.Split(value, ",")
|
||||||
|
}
|
||||||
|
for _, val := range values {
|
||||||
|
val = strings.TrimSpace(val)
|
||||||
|
found := false
|
||||||
|
for _, opt := range def.Options {
|
||||||
|
if opt.Value == val {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return validationError(fmt.Sprintf("%s: invalid option %s", def.Name, val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validationError creates a validation error with a custom message
|
||||||
|
func validationError(msg string) error {
|
||||||
|
return infraerrors.BadRequest("ATTRIBUTE_VALIDATION_FAILED", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidAttributeType(t UserAttributeType) bool {
|
||||||
|
switch t {
|
||||||
|
case AttributeTypeText, AttributeTypeTextarea, AttributeTypeNumber,
|
||||||
|
AttributeTypeEmail, AttributeTypeURL, AttributeTypeDate,
|
||||||
|
AttributeTypeSelect, AttributeTypeMultiSelect:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -14,6 +14,14 @@ var (
|
|||||||
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
|
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UserListFilters contains all filter options for listing users
|
||||||
|
type UserListFilters struct {
|
||||||
|
Status string // User status filter
|
||||||
|
Role string // User role filter
|
||||||
|
Search string // Search in email, username
|
||||||
|
Attributes map[int64]string // Custom attribute filters: attributeID -> value
|
||||||
|
}
|
||||||
|
|
||||||
type UserRepository interface {
|
type UserRepository interface {
|
||||||
Create(ctx context.Context, user *User) error
|
Create(ctx context.Context, user *User) error
|
||||||
GetByID(ctx context.Context, id int64) (*User, error)
|
GetByID(ctx context.Context, id int64) (*User, error)
|
||||||
@@ -23,7 +31,7 @@ type UserRepository interface {
|
|||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
|
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
|
||||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error)
|
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
|
||||||
|
|
||||||
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
||||||
DeductBalance(ctx context.Context, id int64, amount float64) error
|
DeductBalance(ctx context.Context, id int64, amount float64) error
|
||||||
@@ -36,7 +44,6 @@ type UserRepository interface {
|
|||||||
type UpdateProfileRequest struct {
|
type UpdateProfileRequest struct {
|
||||||
Email *string `json:"email"`
|
Email *string `json:"email"`
|
||||||
Username *string `json:"username"`
|
Username *string `json:"username"`
|
||||||
Wechat *string `json:"wechat"`
|
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,10 +107,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
|||||||
user.Username = *req.Username
|
user.Username = *req.Username
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Wechat != nil {
|
|
||||||
user.Wechat = *req.Wechat
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Concurrency != nil {
|
if req.Concurrency != nil {
|
||||||
user.Concurrency = *req.Concurrency
|
user.Concurrency = *req.Concurrency
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,6 +73,15 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
|
||||||
|
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
|
||||||
|
svc := NewConcurrencyService(cache)
|
||||||
|
if cfg != nil {
|
||||||
|
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||||
|
}
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
// ProviderSet is the Wire provider set for all services
|
// ProviderSet is the Wire provider set for all services
|
||||||
var ProviderSet = wire.NewSet(
|
var ProviderSet = wire.NewSet(
|
||||||
// Core services
|
// Core services
|
||||||
@@ -94,6 +103,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewOAuthService,
|
NewOAuthService,
|
||||||
NewOpenAIOAuthService,
|
NewOpenAIOAuthService,
|
||||||
NewGeminiOAuthService,
|
NewGeminiOAuthService,
|
||||||
|
NewGeminiQuotaService,
|
||||||
NewAntigravityOAuthService,
|
NewAntigravityOAuthService,
|
||||||
NewGeminiTokenProvider,
|
NewGeminiTokenProvider,
|
||||||
NewGeminiMessagesCompatService,
|
NewGeminiMessagesCompatService,
|
||||||
@@ -107,7 +117,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
ProvideEmailQueueService,
|
ProvideEmailQueueService,
|
||||||
NewTurnstileService,
|
NewTurnstileService,
|
||||||
NewSubscriptionService,
|
NewSubscriptionService,
|
||||||
NewConcurrencyService,
|
ProvideConcurrencyService,
|
||||||
NewIdentityService,
|
NewIdentityService,
|
||||||
NewCRSSyncService,
|
NewCRSSyncService,
|
||||||
ProvideUpdateService,
|
ProvideUpdateService,
|
||||||
@@ -115,4 +125,5 @@ var ProviderSet = wire.NewSet(
|
|||||||
ProvideTimingWheelService,
|
ProvideTimingWheelService,
|
||||||
ProvideDeferredService,
|
ProvideDeferredService,
|
||||||
ProvideAntigravityQuotaRefresher,
|
ProvideAntigravityQuotaRefresher,
|
||||||
|
NewUserAttributeService,
|
||||||
)
|
)
|
||||||
|
|||||||
48
backend/migrations/018_user_attributes.sql
Normal file
48
backend/migrations/018_user_attributes.sql
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
-- Add user attribute definitions and values tables for custom user attributes.
|
||||||
|
|
||||||
|
-- User Attribute Definitions table (with soft delete support)
|
||||||
|
CREATE TABLE IF NOT EXISTS user_attribute_definitions (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
key VARCHAR(100) NOT NULL,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
description TEXT DEFAULT '',
|
||||||
|
type VARCHAR(20) NOT NULL,
|
||||||
|
options JSONB DEFAULT '[]'::jsonb,
|
||||||
|
required BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
validation JSONB DEFAULT '{}'::jsonb,
|
||||||
|
placeholder VARCHAR(255) DEFAULT '',
|
||||||
|
display_order INT NOT NULL DEFAULT 0,
|
||||||
|
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
deleted_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Partial unique index for key (only for non-deleted records)
|
||||||
|
-- Allows reusing keys after soft delete
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_attribute_definitions_key_unique
|
||||||
|
ON user_attribute_definitions(key) WHERE deleted_at IS NULL;
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_attribute_definitions_enabled
|
||||||
|
ON user_attribute_definitions(enabled);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_attribute_definitions_display_order
|
||||||
|
ON user_attribute_definitions(display_order);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_attribute_definitions_deleted_at
|
||||||
|
ON user_attribute_definitions(deleted_at);
|
||||||
|
|
||||||
|
-- User Attribute Values table (hard delete only, no deleted_at)
|
||||||
|
CREATE TABLE IF NOT EXISTS user_attribute_values (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
attribute_id BIGINT NOT NULL REFERENCES user_attribute_definitions(id) ON DELETE CASCADE,
|
||||||
|
value TEXT DEFAULT '',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
|
||||||
|
UNIQUE(user_id, attribute_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_attribute_values_user_id
|
||||||
|
ON user_attribute_values(user_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_attribute_values_attribute_id
|
||||||
|
ON user_attribute_values(attribute_id);
|
||||||
83
backend/migrations/019_migrate_wechat_to_attributes.sql
Normal file
83
backend/migrations/019_migrate_wechat_to_attributes.sql
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
-- Migration: Move wechat field from users table to user_attribute_values
-- This migration:
-- 1. Creates a "wechat" attribute definition
-- 2. Migrates existing wechat data to user_attribute_values
-- 3. Drops the users.wechat column (Step 4 below). Rollback is provided by
--    the Down migration, which re-adds the column and copies values back.
--    (An earlier revision of this file kept the column for rollback safety;
--    the header previously still said so — corrected to match Step 4.)

-- +goose Up
-- +goose StatementBegin

-- Step 1: Insert wechat attribute definition if not exists
-- (guarded by NOT EXISTS rather than ON CONFLICT because uniqueness on key
-- is a partial index scoped to deleted_at IS NULL).
INSERT INTO user_attribute_definitions (key, name, description, type, options, required, validation, placeholder, display_order, enabled, created_at, updated_at)
SELECT 'wechat', '微信', '用户微信号', 'text', '[]'::jsonb, false, '{}'::jsonb, '请输入微信号', 0, true, NOW(), NOW()
WHERE NOT EXISTS (
    SELECT 1 FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL
);

-- Step 2: Migrate existing wechat values to user_attribute_values
-- Only migrate non-empty values; idempotent — skips users that already have
-- a value row for the wechat attribute.
INSERT INTO user_attribute_values (user_id, attribute_id, value, created_at, updated_at)
SELECT
    u.id,
    (SELECT id FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL LIMIT 1),
    u.wechat,
    NOW(),
    NOW()
FROM users u
WHERE u.wechat IS NOT NULL
  AND u.wechat != ''
  AND u.deleted_at IS NULL
  AND NOT EXISTS (
      SELECT 1 FROM user_attribute_values uav
      WHERE uav.user_id = u.id
        AND uav.attribute_id = (SELECT id FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL LIMIT 1)
  );

-- Step 3: Update display_order to ensure wechat appears first
-- (temporarily set to -1 so the renumbering below places it at 0).
UPDATE user_attribute_definitions
SET display_order = -1
WHERE key = 'wechat' AND deleted_at IS NULL;

-- Reorder all attributes starting from 0, preserving relative order
-- (ties broken by id for determinism).
WITH ordered AS (
    SELECT id, ROW_NUMBER() OVER (ORDER BY display_order, id) - 1 as new_order
    FROM user_attribute_definitions
    WHERE deleted_at IS NULL
)
UPDATE user_attribute_definitions
SET display_order = ordered.new_order
FROM ordered
WHERE user_attribute_definitions.id = ordered.id;

-- Step 4: Drop the redundant wechat column from users table
ALTER TABLE users DROP COLUMN IF EXISTS wechat;

-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin

-- Restore wechat column
-- NOTE(review): assumes the original column was VARCHAR(100) DEFAULT '' —
-- confirm against the schema in effect before this migration.
ALTER TABLE users ADD COLUMN IF NOT EXISTS wechat VARCHAR(100) DEFAULT '';

-- Copy attribute values back to users.wechat column
UPDATE users u
SET wechat = uav.value
FROM user_attribute_values uav
JOIN user_attribute_definitions uad ON uav.attribute_id = uad.id
WHERE uav.user_id = u.id
  AND uad.key = 'wechat'
  AND uad.deleted_at IS NULL;

-- Delete migrated attribute values
DELETE FROM user_attribute_values
WHERE attribute_id IN (
    SELECT id FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL
);

-- Soft-delete the wechat attribute definition
-- (soft delete frees the key for reuse via the partial unique index).
UPDATE user_attribute_definitions
SET deleted_at = NOW()
WHERE key = 'wechat' AND deleted_at IS NULL;

-- +goose StatementEnd
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user