merge: 合并 upstream/main

This commit is contained in:
song
2026-01-02 10:22:29 +08:00
135 changed files with 26830 additions and 1442 deletions

1
.gitignore vendored
View File

@@ -77,6 +77,7 @@ temp/
*.temp
*.log
*.bak
.cache/
# ===================
# 构建产物

View File

@@ -87,9 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
@@ -99,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
@@ -116,7 +117,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
@@ -127,10 +132,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache)
timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)

View File

@@ -25,6 +25,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
stdsql "database/sql"
@@ -55,6 +57,10 @@ type Client struct {
User *UserClient
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
UserAllowedGroup *UserAllowedGroupClient
// UserAttributeDefinition is the client for interacting with the UserAttributeDefinition builders.
UserAttributeDefinition *UserAttributeDefinitionClient
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
UserAttributeValue *UserAttributeValueClient
// UserSubscription is the client for interacting with the UserSubscription builders.
UserSubscription *UserSubscriptionClient
}
@@ -78,6 +84,8 @@ func (c *Client) init() {
c.UsageLog = NewUsageLogClient(c.config)
c.User = NewUserClient(c.config)
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
c.UserAttributeDefinition = NewUserAttributeDefinitionClient(c.config)
c.UserAttributeValue = NewUserAttributeValueClient(c.config)
c.UserSubscription = NewUserSubscriptionClient(c.config)
}
@@ -169,19 +177,21 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
ctx: ctx,
config: cfg,
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
ApiKey: NewApiKeyClient(cfg),
Group: NewGroupClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
ctx: ctx,
config: cfg,
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
ApiKey: NewApiKeyClient(cfg),
Group: NewGroupClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -199,19 +209,21 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
ctx: ctx,
config: cfg,
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
ApiKey: NewApiKeyClient(cfg),
Group: NewGroupClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
ctx: ctx,
config: cfg,
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
ApiKey: NewApiKeyClient(cfg),
Group: NewGroupClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -242,7 +254,8 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription,
c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition,
c.UserAttributeValue, c.UserSubscription,
} {
n.Use(hooks...)
}
@@ -253,7 +266,8 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription,
c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition,
c.UserAttributeValue, c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@@ -282,6 +296,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.User.mutate(ctx, m)
case *UserAllowedGroupMutation:
return c.UserAllowedGroup.mutate(ctx, m)
case *UserAttributeDefinitionMutation:
return c.UserAttributeDefinition.mutate(ctx, m)
case *UserAttributeValueMutation:
return c.UserAttributeValue.mutate(ctx, m)
case *UserSubscriptionMutation:
return c.UserSubscription.mutate(ctx, m)
default:
@@ -1916,6 +1934,22 @@ func (c *UserClient) QueryUsageLogs(_m *User) *UsageLogQuery {
return query
}
// QueryAttributeValues queries the attribute_values edge of a User.
func (c *UserClient) QueryAttributeValues(_m *User) *UserAttributeValueQuery {
query := (&UserAttributeValueClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AttributeValuesTable, user.AttributeValuesColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@@ -2075,6 +2109,322 @@ func (c *UserAllowedGroupClient) mutate(ctx context.Context, m *UserAllowedGroup
}
}
// UserAttributeDefinitionClient is a client for the UserAttributeDefinition schema.
type UserAttributeDefinitionClient struct {
config
}
// NewUserAttributeDefinitionClient returns a client for the UserAttributeDefinition from the given config.
func NewUserAttributeDefinitionClient(c config) *UserAttributeDefinitionClient {
return &UserAttributeDefinitionClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `userattributedefinition.Hooks(f(g(h())))`.
func (c *UserAttributeDefinitionClient) Use(hooks ...Hook) {
c.hooks.UserAttributeDefinition = append(c.hooks.UserAttributeDefinition, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `userattributedefinition.Intercept(f(g(h())))`.
func (c *UserAttributeDefinitionClient) Intercept(interceptors ...Interceptor) {
c.inters.UserAttributeDefinition = append(c.inters.UserAttributeDefinition, interceptors...)
}
// Create returns a builder for creating a UserAttributeDefinition entity.
func (c *UserAttributeDefinitionClient) Create() *UserAttributeDefinitionCreate {
mutation := newUserAttributeDefinitionMutation(c.config, OpCreate)
return &UserAttributeDefinitionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of UserAttributeDefinition entities.
func (c *UserAttributeDefinitionClient) CreateBulk(builders ...*UserAttributeDefinitionCreate) *UserAttributeDefinitionCreateBulk {
return &UserAttributeDefinitionCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UserAttributeDefinitionClient) MapCreateBulk(slice any, setFunc func(*UserAttributeDefinitionCreate, int)) *UserAttributeDefinitionCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UserAttributeDefinitionCreateBulk{err: fmt.Errorf("calling to UserAttributeDefinitionClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UserAttributeDefinitionCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UserAttributeDefinitionCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for UserAttributeDefinition.
func (c *UserAttributeDefinitionClient) Update() *UserAttributeDefinitionUpdate {
mutation := newUserAttributeDefinitionMutation(c.config, OpUpdate)
return &UserAttributeDefinitionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UserAttributeDefinitionClient) UpdateOne(_m *UserAttributeDefinition) *UserAttributeDefinitionUpdateOne {
mutation := newUserAttributeDefinitionMutation(c.config, OpUpdateOne, withUserAttributeDefinition(_m))
return &UserAttributeDefinitionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UserAttributeDefinitionClient) UpdateOneID(id int64) *UserAttributeDefinitionUpdateOne {
mutation := newUserAttributeDefinitionMutation(c.config, OpUpdateOne, withUserAttributeDefinitionID(id))
return &UserAttributeDefinitionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for UserAttributeDefinition.
func (c *UserAttributeDefinitionClient) Delete() *UserAttributeDefinitionDelete {
mutation := newUserAttributeDefinitionMutation(c.config, OpDelete)
return &UserAttributeDefinitionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UserAttributeDefinitionClient) DeleteOne(_m *UserAttributeDefinition) *UserAttributeDefinitionDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UserAttributeDefinitionClient) DeleteOneID(id int64) *UserAttributeDefinitionDeleteOne {
builder := c.Delete().Where(userattributedefinition.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UserAttributeDefinitionDeleteOne{builder}
}
// Query returns a query builder for UserAttributeDefinition.
func (c *UserAttributeDefinitionClient) Query() *UserAttributeDefinitionQuery {
return &UserAttributeDefinitionQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUserAttributeDefinition},
inters: c.Interceptors(),
}
}
// Get returns a UserAttributeDefinition entity by its id.
func (c *UserAttributeDefinitionClient) Get(ctx context.Context, id int64) (*UserAttributeDefinition, error) {
return c.Query().Where(userattributedefinition.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UserAttributeDefinitionClient) GetX(ctx context.Context, id int64) *UserAttributeDefinition {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryValues queries the values edge of a UserAttributeDefinition.
func (c *UserAttributeDefinitionClient) QueryValues(_m *UserAttributeDefinition) *UserAttributeValueQuery {
query := (&UserAttributeValueClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(userattributedefinition.Table, userattributedefinition.FieldID, id),
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, userattributedefinition.ValuesTable, userattributedefinition.ValuesColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *UserAttributeDefinitionClient) Hooks() []Hook {
hooks := c.hooks.UserAttributeDefinition
return append(hooks[:len(hooks):len(hooks)], userattributedefinition.Hooks[:]...)
}
// Interceptors returns the client interceptors.
func (c *UserAttributeDefinitionClient) Interceptors() []Interceptor {
inters := c.inters.UserAttributeDefinition
return append(inters[:len(inters):len(inters)], userattributedefinition.Interceptors[:]...)
}
func (c *UserAttributeDefinitionClient) mutate(ctx context.Context, m *UserAttributeDefinitionMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UserAttributeDefinitionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UserAttributeDefinitionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UserAttributeDefinitionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UserAttributeDefinitionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown UserAttributeDefinition mutation op: %q", m.Op())
}
}
// UserAttributeValueClient is a client for the UserAttributeValue schema.
type UserAttributeValueClient struct {
config
}
// NewUserAttributeValueClient returns a client for the UserAttributeValue from the given config.
func NewUserAttributeValueClient(c config) *UserAttributeValueClient {
return &UserAttributeValueClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `userattributevalue.Hooks(f(g(h())))`.
func (c *UserAttributeValueClient) Use(hooks ...Hook) {
c.hooks.UserAttributeValue = append(c.hooks.UserAttributeValue, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `userattributevalue.Intercept(f(g(h())))`.
func (c *UserAttributeValueClient) Intercept(interceptors ...Interceptor) {
c.inters.UserAttributeValue = append(c.inters.UserAttributeValue, interceptors...)
}
// Create returns a builder for creating a UserAttributeValue entity.
func (c *UserAttributeValueClient) Create() *UserAttributeValueCreate {
mutation := newUserAttributeValueMutation(c.config, OpCreate)
return &UserAttributeValueCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of UserAttributeValue entities.
func (c *UserAttributeValueClient) CreateBulk(builders ...*UserAttributeValueCreate) *UserAttributeValueCreateBulk {
return &UserAttributeValueCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UserAttributeValueClient) MapCreateBulk(slice any, setFunc func(*UserAttributeValueCreate, int)) *UserAttributeValueCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UserAttributeValueCreateBulk{err: fmt.Errorf("calling to UserAttributeValueClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UserAttributeValueCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UserAttributeValueCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for UserAttributeValue.
func (c *UserAttributeValueClient) Update() *UserAttributeValueUpdate {
mutation := newUserAttributeValueMutation(c.config, OpUpdate)
return &UserAttributeValueUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UserAttributeValueClient) UpdateOne(_m *UserAttributeValue) *UserAttributeValueUpdateOne {
mutation := newUserAttributeValueMutation(c.config, OpUpdateOne, withUserAttributeValue(_m))
return &UserAttributeValueUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UserAttributeValueClient) UpdateOneID(id int64) *UserAttributeValueUpdateOne {
mutation := newUserAttributeValueMutation(c.config, OpUpdateOne, withUserAttributeValueID(id))
return &UserAttributeValueUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for UserAttributeValue.
func (c *UserAttributeValueClient) Delete() *UserAttributeValueDelete {
mutation := newUserAttributeValueMutation(c.config, OpDelete)
return &UserAttributeValueDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UserAttributeValueClient) DeleteOne(_m *UserAttributeValue) *UserAttributeValueDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UserAttributeValueClient) DeleteOneID(id int64) *UserAttributeValueDeleteOne {
builder := c.Delete().Where(userattributevalue.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UserAttributeValueDeleteOne{builder}
}
// Query returns a query builder for UserAttributeValue.
func (c *UserAttributeValueClient) Query() *UserAttributeValueQuery {
return &UserAttributeValueQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUserAttributeValue},
inters: c.Interceptors(),
}
}
// Get returns a UserAttributeValue entity by its id.
func (c *UserAttributeValueClient) Get(ctx context.Context, id int64) (*UserAttributeValue, error) {
return c.Query().Where(userattributevalue.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UserAttributeValueClient) GetX(ctx context.Context, id int64) *UserAttributeValue {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryUser queries the user edge of a UserAttributeValue.
func (c *UserAttributeValueClient) QueryUser(_m *UserAttributeValue) *UserQuery {
query := (&UserClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, id),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.UserTable, userattributevalue.UserColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryDefinition queries the definition edge of a UserAttributeValue.
func (c *UserAttributeValueClient) QueryDefinition(_m *UserAttributeValue) *UserAttributeDefinitionQuery {
query := (&UserAttributeDefinitionClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, id),
sqlgraph.To(userattributedefinition.Table, userattributedefinition.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.DefinitionTable, userattributevalue.DefinitionColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *UserAttributeValueClient) Hooks() []Hook {
return c.hooks.UserAttributeValue
}
// Interceptors returns the client interceptors.
func (c *UserAttributeValueClient) Interceptors() []Interceptor {
return c.inters.UserAttributeValue
}
func (c *UserAttributeValueClient) mutate(ctx context.Context, m *UserAttributeValueMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UserAttributeValueCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UserAttributeValueUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UserAttributeValueUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UserAttributeValueDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown UserAttributeValue mutation op: %q", m.Op())
}
}
// UserSubscriptionClient is a client for the UserSubscription schema.
type UserSubscriptionClient struct {
config
@@ -2278,11 +2628,13 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
type (
hooks struct {
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
User, UserAllowedGroup, UserSubscription []ent.Hook
User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
UserSubscription []ent.Hook
}
inters struct {
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
User, UserAllowedGroup, UserSubscription []ent.Interceptor
User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
UserSubscription []ent.Interceptor
}
)

View File

@@ -22,6 +22,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -83,17 +85,19 @@ var (
func checkColumn(t, c string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
account.Table: account.ValidColumn,
accountgroup.Table: accountgroup.ValidColumn,
apikey.Table: apikey.ValidColumn,
group.Table: group.ValidColumn,
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
setting.Table: setting.ValidColumn,
usagelog.Table: usagelog.ValidColumn,
user.Table: user.ValidColumn,
userallowedgroup.Table: userallowedgroup.ValidColumn,
usersubscription.Table: usersubscription.ValidColumn,
account.Table: account.ValidColumn,
accountgroup.Table: accountgroup.ValidColumn,
apikey.Table: apikey.ValidColumn,
group.Table: group.ValidColumn,
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
setting.Table: setting.ValidColumn,
usagelog.Table: usagelog.ValidColumn,
user.Table: user.ValidColumn,
userallowedgroup.Table: userallowedgroup.ValidColumn,
userattributedefinition.Table: userattributedefinition.ValidColumn,
userattributevalue.Table: userattributevalue.ValidColumn,
usersubscription.Table: usersubscription.ValidColumn,
})
})
return columnCheck(t, c)

View File

@@ -129,6 +129,30 @@ func (f UserAllowedGroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAllowedGroupMutation", m)
}
// The UserAttributeDefinitionFunc type is an adapter to allow the use of ordinary
// function as UserAttributeDefinition mutator.
type UserAttributeDefinitionFunc func(context.Context, *ent.UserAttributeDefinitionMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UserAttributeDefinitionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UserAttributeDefinitionMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeDefinitionMutation", m)
}
// The UserAttributeValueFunc type is an adapter to allow the use of ordinary
// function as UserAttributeValue mutator.
type UserAttributeValueFunc func(context.Context, *ent.UserAttributeValueMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UserAttributeValueFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UserAttributeValueMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeValueMutation", m)
}
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary
// function as UserSubscription mutator.
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error)

View File

@@ -19,6 +19,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -348,6 +350,60 @@ func (f TraverseUserAllowedGroup) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.UserAllowedGroupQuery", q)
}
// The UserAttributeDefinitionFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserAttributeDefinitionFunc func(context.Context, *ent.UserAttributeDefinitionQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UserAttributeDefinitionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UserAttributeDefinitionQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeDefinitionQuery", q)
}
// The TraverseUserAttributeDefinition type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUserAttributeDefinition func(context.Context, *ent.UserAttributeDefinitionQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUserAttributeDefinition) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUserAttributeDefinition) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UserAttributeDefinitionQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeDefinitionQuery", q)
}
// The UserAttributeValueFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserAttributeValueFunc func(context.Context, *ent.UserAttributeValueQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UserAttributeValueFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UserAttributeValueQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
}
// The TraverseUserAttributeValue type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUserAttributeValue func(context.Context, *ent.UserAttributeValueQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUserAttributeValue) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUserAttributeValue) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UserAttributeValueQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
}
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error)
@@ -398,6 +454,10 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.UserQuery, predicate.User, user.OrderOption]{typ: ent.TypeUser, tq: q}, nil
case *ent.UserAllowedGroupQuery:
return &query[*ent.UserAllowedGroupQuery, predicate.UserAllowedGroup, userallowedgroup.OrderOption]{typ: ent.TypeUserAllowedGroup, tq: q}, nil
case *ent.UserAttributeDefinitionQuery:
return &query[*ent.UserAttributeDefinitionQuery, predicate.UserAttributeDefinition, userattributedefinition.OrderOption]{typ: ent.TypeUserAttributeDefinition, tq: q}, nil
case *ent.UserAttributeValueQuery:
return &query[*ent.UserAttributeValueQuery, predicate.UserAttributeValue, userattributevalue.OrderOption]{typ: ent.TypeUserAttributeValue, tq: q}, nil
case *ent.UserSubscriptionQuery:
return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil
default:

View File

@@ -477,7 +477,6 @@ var (
{Name: "concurrency", Type: field.TypeInt, Default: 5},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
{Name: "wechat", Type: field.TypeString, Size: 100, Default: ""},
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
}
// UsersTable holds the schema information for the "users" table.
@@ -531,6 +530,92 @@ var (
},
},
}
// UserAttributeDefinitionsColumns holds the columns for the "user_attribute_definitions" table.
UserAttributeDefinitionsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "key", Type: field.TypeString, Size: 100},
{Name: "name", Type: field.TypeString, Size: 255},
{Name: "description", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "type", Type: field.TypeString, Size: 20},
{Name: "options", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "required", Type: field.TypeBool, Default: false},
{Name: "validation", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "placeholder", Type: field.TypeString, Size: 255, Default: ""},
{Name: "display_order", Type: field.TypeInt, Default: 0},
{Name: "enabled", Type: field.TypeBool, Default: true},
}
// UserAttributeDefinitionsTable holds the schema information for the "user_attribute_definitions" table.
UserAttributeDefinitionsTable = &schema.Table{
Name: "user_attribute_definitions",
Columns: UserAttributeDefinitionsColumns,
PrimaryKey: []*schema.Column{UserAttributeDefinitionsColumns[0]},
Indexes: []*schema.Index{
{
Name: "userattributedefinition_key",
Unique: false,
Columns: []*schema.Column{UserAttributeDefinitionsColumns[4]},
},
{
Name: "userattributedefinition_enabled",
Unique: false,
Columns: []*schema.Column{UserAttributeDefinitionsColumns[13]},
},
{
Name: "userattributedefinition_display_order",
Unique: false,
Columns: []*schema.Column{UserAttributeDefinitionsColumns[12]},
},
{
Name: "userattributedefinition_deleted_at",
Unique: false,
Columns: []*schema.Column{UserAttributeDefinitionsColumns[3]},
},
},
}
// UserAttributeValuesColumns holds the columns for the "user_attribute_values" table.
UserAttributeValuesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "value", Type: field.TypeString, Size: 2147483647, Default: ""},
{Name: "user_id", Type: field.TypeInt64},
{Name: "attribute_id", Type: field.TypeInt64},
}
// UserAttributeValuesTable holds the schema information for the "user_attribute_values" table.
UserAttributeValuesTable = &schema.Table{
Name: "user_attribute_values",
Columns: UserAttributeValuesColumns,
PrimaryKey: []*schema.Column{UserAttributeValuesColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "user_attribute_values_users_attribute_values",
Columns: []*schema.Column{UserAttributeValuesColumns[4]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "user_attribute_values_user_attribute_definitions_values",
Columns: []*schema.Column{UserAttributeValuesColumns[5]},
RefColumns: []*schema.Column{UserAttributeDefinitionsColumns[0]},
OnDelete: schema.NoAction,
},
},
Indexes: []*schema.Index{
{
Name: "userattributevalue_user_id_attribute_id",
Unique: true,
Columns: []*schema.Column{UserAttributeValuesColumns[4], UserAttributeValuesColumns[5]},
},
{
Name: "userattributevalue_attribute_id",
Unique: false,
Columns: []*schema.Column{UserAttributeValuesColumns[5]},
},
},
}
// UserSubscriptionsColumns holds the columns for the "user_subscriptions" table.
UserSubscriptionsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -627,6 +712,8 @@ var (
UsageLogsTable,
UsersTable,
UserAllowedGroupsTable,
UserAttributeDefinitionsTable,
UserAttributeValuesTable,
UserSubscriptionsTable,
}
)
@@ -676,6 +763,14 @@ func init() {
UserAllowedGroupsTable.Annotation = &entsql.Annotation{
Table: "user_allowed_groups",
}
UserAttributeDefinitionsTable.Annotation = &entsql.Annotation{
Table: "user_attribute_definitions",
}
UserAttributeValuesTable.ForeignKeys[0].RefTable = UsersTable
UserAttributeValuesTable.ForeignKeys[1].RefTable = UserAttributeDefinitionsTable
UserAttributeValuesTable.Annotation = &entsql.Annotation{
Table: "user_attribute_values",
}
UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable
UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable
UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable

File diff suppressed because it is too large Load Diff

View File

@@ -36,5 +36,11 @@ type User func(*sql.Selector)
// UserAllowedGroup is the predicate function for userallowedgroup builders.
type UserAllowedGroup func(*sql.Selector)
// UserAttributeDefinition is the predicate function for userattributedefinition builders.
type UserAttributeDefinition func(*sql.Selector)
// UserAttributeValue is the predicate function for userattributevalue builders.
type UserAttributeValue func(*sql.Selector)
// UserSubscription is the predicate function for usersubscription builders.
type UserSubscription func(*sql.Selector)

View File

@@ -16,6 +16,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -604,14 +606,8 @@ func init() {
user.DefaultUsername = userDescUsername.Default.(string)
// user.UsernameValidator is a validator for the "username" field. It is called by the builders before save.
user.UsernameValidator = userDescUsername.Validators[0].(func(string) error)
// userDescWechat is the schema descriptor for wechat field.
userDescWechat := userFields[7].Descriptor()
// user.DefaultWechat holds the default value on creation for the wechat field.
user.DefaultWechat = userDescWechat.Default.(string)
// user.WechatValidator is a validator for the "wechat" field. It is called by the builders before save.
user.WechatValidator = userDescWechat.Validators[0].(func(string) error)
// userDescNotes is the schema descriptor for notes field.
userDescNotes := userFields[8].Descriptor()
userDescNotes := userFields[7].Descriptor()
// user.DefaultNotes holds the default value on creation for the notes field.
user.DefaultNotes = userDescNotes.Default.(string)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
@@ -620,6 +616,128 @@ func init() {
userallowedgroupDescCreatedAt := userallowedgroupFields[2].Descriptor()
// userallowedgroup.DefaultCreatedAt holds the default value on creation for the created_at field.
userallowedgroup.DefaultCreatedAt = userallowedgroupDescCreatedAt.Default.(func() time.Time)
userattributedefinitionMixin := schema.UserAttributeDefinition{}.Mixin()
userattributedefinitionMixinHooks1 := userattributedefinitionMixin[1].Hooks()
userattributedefinition.Hooks[0] = userattributedefinitionMixinHooks1[0]
userattributedefinitionMixinInters1 := userattributedefinitionMixin[1].Interceptors()
userattributedefinition.Interceptors[0] = userattributedefinitionMixinInters1[0]
userattributedefinitionMixinFields0 := userattributedefinitionMixin[0].Fields()
_ = userattributedefinitionMixinFields0
userattributedefinitionFields := schema.UserAttributeDefinition{}.Fields()
_ = userattributedefinitionFields
// userattributedefinitionDescCreatedAt is the schema descriptor for created_at field.
userattributedefinitionDescCreatedAt := userattributedefinitionMixinFields0[0].Descriptor()
// userattributedefinition.DefaultCreatedAt holds the default value on creation for the created_at field.
userattributedefinition.DefaultCreatedAt = userattributedefinitionDescCreatedAt.Default.(func() time.Time)
// userattributedefinitionDescUpdatedAt is the schema descriptor for updated_at field.
userattributedefinitionDescUpdatedAt := userattributedefinitionMixinFields0[1].Descriptor()
// userattributedefinition.DefaultUpdatedAt holds the default value on creation for the updated_at field.
userattributedefinition.DefaultUpdatedAt = userattributedefinitionDescUpdatedAt.Default.(func() time.Time)
// userattributedefinition.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
userattributedefinition.UpdateDefaultUpdatedAt = userattributedefinitionDescUpdatedAt.UpdateDefault.(func() time.Time)
// userattributedefinitionDescKey is the schema descriptor for key field.
userattributedefinitionDescKey := userattributedefinitionFields[0].Descriptor()
// userattributedefinition.KeyValidator is a validator for the "key" field. It is called by the builders before save.
userattributedefinition.KeyValidator = func() func(string) error {
validators := userattributedefinitionDescKey.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(key string) error {
for _, fn := range fns {
if err := fn(key); err != nil {
return err
}
}
return nil
}
}()
// userattributedefinitionDescName is the schema descriptor for name field.
userattributedefinitionDescName := userattributedefinitionFields[1].Descriptor()
// userattributedefinition.NameValidator is a validator for the "name" field. It is called by the builders before save.
userattributedefinition.NameValidator = func() func(string) error {
validators := userattributedefinitionDescName.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(name string) error {
for _, fn := range fns {
if err := fn(name); err != nil {
return err
}
}
return nil
}
}()
// userattributedefinitionDescDescription is the schema descriptor for description field.
userattributedefinitionDescDescription := userattributedefinitionFields[2].Descriptor()
// userattributedefinition.DefaultDescription holds the default value on creation for the description field.
userattributedefinition.DefaultDescription = userattributedefinitionDescDescription.Default.(string)
// userattributedefinitionDescType is the schema descriptor for type field.
userattributedefinitionDescType := userattributedefinitionFields[3].Descriptor()
// userattributedefinition.TypeValidator is a validator for the "type" field. It is called by the builders before save.
userattributedefinition.TypeValidator = func() func(string) error {
validators := userattributedefinitionDescType.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(_type string) error {
for _, fn := range fns {
if err := fn(_type); err != nil {
return err
}
}
return nil
}
}()
// userattributedefinitionDescOptions is the schema descriptor for options field.
userattributedefinitionDescOptions := userattributedefinitionFields[4].Descriptor()
// userattributedefinition.DefaultOptions holds the default value on creation for the options field.
userattributedefinition.DefaultOptions = userattributedefinitionDescOptions.Default.([]map[string]interface{})
// userattributedefinitionDescRequired is the schema descriptor for required field.
userattributedefinitionDescRequired := userattributedefinitionFields[5].Descriptor()
// userattributedefinition.DefaultRequired holds the default value on creation for the required field.
userattributedefinition.DefaultRequired = userattributedefinitionDescRequired.Default.(bool)
// userattributedefinitionDescValidation is the schema descriptor for validation field.
userattributedefinitionDescValidation := userattributedefinitionFields[6].Descriptor()
// userattributedefinition.DefaultValidation holds the default value on creation for the validation field.
userattributedefinition.DefaultValidation = userattributedefinitionDescValidation.Default.(map[string]interface{})
// userattributedefinitionDescPlaceholder is the schema descriptor for placeholder field.
userattributedefinitionDescPlaceholder := userattributedefinitionFields[7].Descriptor()
// userattributedefinition.DefaultPlaceholder holds the default value on creation for the placeholder field.
userattributedefinition.DefaultPlaceholder = userattributedefinitionDescPlaceholder.Default.(string)
// userattributedefinition.PlaceholderValidator is a validator for the "placeholder" field. It is called by the builders before save.
userattributedefinition.PlaceholderValidator = userattributedefinitionDescPlaceholder.Validators[0].(func(string) error)
// userattributedefinitionDescDisplayOrder is the schema descriptor for display_order field.
userattributedefinitionDescDisplayOrder := userattributedefinitionFields[8].Descriptor()
// userattributedefinition.DefaultDisplayOrder holds the default value on creation for the display_order field.
userattributedefinition.DefaultDisplayOrder = userattributedefinitionDescDisplayOrder.Default.(int)
// userattributedefinitionDescEnabled is the schema descriptor for enabled field.
userattributedefinitionDescEnabled := userattributedefinitionFields[9].Descriptor()
// userattributedefinition.DefaultEnabled holds the default value on creation for the enabled field.
userattributedefinition.DefaultEnabled = userattributedefinitionDescEnabled.Default.(bool)
userattributevalueMixin := schema.UserAttributeValue{}.Mixin()
userattributevalueMixinFields0 := userattributevalueMixin[0].Fields()
_ = userattributevalueMixinFields0
userattributevalueFields := schema.UserAttributeValue{}.Fields()
_ = userattributevalueFields
// userattributevalueDescCreatedAt is the schema descriptor for created_at field.
userattributevalueDescCreatedAt := userattributevalueMixinFields0[0].Descriptor()
// userattributevalue.DefaultCreatedAt holds the default value on creation for the created_at field.
userattributevalue.DefaultCreatedAt = userattributevalueDescCreatedAt.Default.(func() time.Time)
// userattributevalueDescUpdatedAt is the schema descriptor for updated_at field.
userattributevalueDescUpdatedAt := userattributevalueMixinFields0[1].Descriptor()
// userattributevalue.DefaultUpdatedAt holds the default value on creation for the updated_at field.
userattributevalue.DefaultUpdatedAt = userattributevalueDescUpdatedAt.Default.(func() time.Time)
// userattributevalue.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
userattributevalue.UpdateDefaultUpdatedAt = userattributevalueDescUpdatedAt.UpdateDefault.(func() time.Time)
// userattributevalueDescValue is the schema descriptor for value field.
userattributevalueDescValue := userattributevalueFields[2].Descriptor()
// userattributevalue.DefaultValue holds the default value on creation for the value field.
userattributevalue.DefaultValue = userattributevalueDescValue.Default.(string)
usersubscriptionMixin := schema.UserSubscription{}.Mixin()
usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks()
usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0]

View File

@@ -57,9 +57,7 @@ func (User) Fields() []ent.Field {
field.String("username").
MaxLen(100).
Default(""),
field.String("wechat").
MaxLen(100).
Default(""),
// wechat field migrated to user_attribute_values (see migration 019)
field.String("notes").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""),
@@ -75,6 +73,7 @@ func (User) Edges() []ent.Edge {
edge.To("allowed_groups", Group.Type).
Through("user_allowed_groups", UserAllowedGroup.Type),
edge.To("usage_logs", UsageLog.Type),
edge.To("attribute_values", UserAttributeValue.Type),
}
}

View File

@@ -0,0 +1,109 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// UserAttributeDefinition holds the schema definition for custom user attributes.
//
// This entity defines the metadata for user attributes, such as:
// - Attribute key (unique identifier like "company_name")
// - Display name shown in forms
// - Field type (text, number, select, etc.)
// - Validation rules
// - Whether the field is required or enabled
type UserAttributeDefinition struct {
ent.Schema
}
func (UserAttributeDefinition) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "user_attribute_definitions"},
}
}
func (UserAttributeDefinition) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
mixins.SoftDeleteMixin{},
}
}
func (UserAttributeDefinition) Fields() []ent.Field {
return []ent.Field{
// key: Unique identifier for the attribute (e.g., "company_name")
// Used for programmatic reference
field.String("key").
MaxLen(100).
NotEmpty(),
// name: Display name shown in forms (e.g., "Company Name")
field.String("name").
MaxLen(255).
NotEmpty(),
// description: Optional description/help text for the attribute
field.String("description").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""),
// type: Attribute type - text, textarea, number, email, url, date, select, multi_select
field.String("type").
MaxLen(20).
NotEmpty(),
// options: Select options for select/multi_select types (stored as JSONB)
// Format: [{"value": "xxx", "label": "XXX"}, ...]
field.JSON("options", []map[string]any{}).
Default([]map[string]any{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// required: Whether this attribute is required when editing a user
field.Bool("required").
Default(false),
// validation: Validation rules for the attribute value (stored as JSONB)
// Format: {"min_length": 1, "max_length": 100, "min": 0, "max": 100, "pattern": "^[a-z]+$", "message": "..."}
field.JSON("validation", map[string]any{}).
Default(map[string]any{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// placeholder: Placeholder text shown in input fields
field.String("placeholder").
MaxLen(255).
Default(""),
// display_order: Order in which attributes are displayed (lower = first)
field.Int("display_order").
Default(0),
// enabled: Whether this attribute is active and shown in forms
field.Bool("enabled").
Default(true),
}
}
func (UserAttributeDefinition) Edges() []ent.Edge {
return []ent.Edge{
// values: All user values for this attribute definition
edge.To("values", UserAttributeValue.Type),
}
}
func (UserAttributeDefinition) Indexes() []ent.Index {
return []ent.Index{
// Partial unique index on key (WHERE deleted_at IS NULL) via migration
index.Fields("key"),
index.Fields("enabled"),
index.Fields("display_order"),
index.Fields("deleted_at"),
}
}

View File

@@ -0,0 +1,74 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// UserAttributeValue holds a user's value for a specific attribute.
//
// This entity stores the actual values that users have for each attribute definition.
// Values are stored as strings and converted to the appropriate type by the application.
type UserAttributeValue struct {
ent.Schema
}
func (UserAttributeValue) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "user_attribute_values"},
}
}
func (UserAttributeValue) Mixin() []ent.Mixin {
return []ent.Mixin{
// Only use TimeMixin, no soft delete - values are hard deleted
mixins.TimeMixin{},
}
}
func (UserAttributeValue) Fields() []ent.Field {
return []ent.Field{
// user_id: References the user this value belongs to
field.Int64("user_id"),
// attribute_id: References the attribute definition
field.Int64("attribute_id"),
// value: The actual value stored as a string
// For multi_select, this is a JSON array string
field.Text("value").
Default(""),
}
}
func (UserAttributeValue) Edges() []ent.Edge {
return []ent.Edge{
// user: The user who owns this attribute value
edge.From("user", User.Type).
Ref("attribute_values").
Field("user_id").
Required().
Unique(),
// definition: The attribute definition this value is for
edge.From("definition", UserAttributeDefinition.Type).
Ref("values").
Field("attribute_id").
Required().
Unique(),
}
}
func (UserAttributeValue) Indexes() []ent.Index {
return []ent.Index{
// Unique index on (user_id, attribute_id)
index.Fields("user_id", "attribute_id").Unique(),
index.Fields("attribute_id"),
}
}

View File

@@ -34,6 +34,10 @@ type Tx struct {
User *UserClient
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
UserAllowedGroup *UserAllowedGroupClient
// UserAttributeDefinition is the client for interacting with the UserAttributeDefinition builders.
UserAttributeDefinition *UserAttributeDefinitionClient
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
UserAttributeValue *UserAttributeValueClient
// UserSubscription is the client for interacting with the UserSubscription builders.
UserSubscription *UserSubscriptionClient
@@ -177,6 +181,8 @@ func (tx *Tx) init() {
tx.UsageLog = NewUsageLogClient(tx.config)
tx.User = NewUserClient(tx.config)
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
tx.UserAttributeDefinition = NewUserAttributeDefinitionClient(tx.config)
tx.UserAttributeValue = NewUserAttributeValueClient(tx.config)
tx.UserSubscription = NewUserSubscriptionClient(tx.config)
}

View File

@@ -37,8 +37,6 @@ type User struct {
Status string `json:"status,omitempty"`
// Username holds the value of the "username" field.
Username string `json:"username,omitempty"`
// Wechat holds the value of the "wechat" field.
Wechat string `json:"wechat,omitempty"`
// Notes holds the value of the "notes" field.
Notes string `json:"notes,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
@@ -61,11 +59,13 @@ type UserEdges struct {
AllowedGroups []*Group `json:"allowed_groups,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// AttributeValues holds the value of the attribute_values edge.
AttributeValues []*UserAttributeValue `json:"attribute_values,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [7]bool
loadedTypes [8]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -122,10 +122,19 @@ func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
return nil, &NotLoadedError{edge: "usage_logs"}
}
// AttributeValuesOrErr returns the AttributeValues value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
if e.loadedTypes[6] {
return e.AttributeValues, nil
}
return nil, &NotLoadedError{edge: "attribute_values"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[6] {
if e.loadedTypes[7] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -140,7 +149,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldWechat, user.FieldNotes:
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
values[i] = new(sql.NullTime)
@@ -226,12 +235,6 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Username = value.String
}
case user.FieldWechat:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field wechat", values[i])
} else if value.Valid {
_m.Wechat = value.String
}
case user.FieldNotes:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field notes", values[i])
@@ -281,6 +284,11 @@ func (_m *User) QueryUsageLogs() *UsageLogQuery {
return NewUserClient(_m.config).QueryUsageLogs(_m)
}
// QueryAttributeValues queries the "attribute_values" edge of the User entity.
func (_m *User) QueryAttributeValues() *UserAttributeValueQuery {
return NewUserClient(_m.config).QueryAttributeValues(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
@@ -341,9 +349,6 @@ func (_m *User) String() string {
builder.WriteString("username=")
builder.WriteString(_m.Username)
builder.WriteString(", ")
builder.WriteString("wechat=")
builder.WriteString(_m.Wechat)
builder.WriteString(", ")
builder.WriteString("notes=")
builder.WriteString(_m.Notes)
builder.WriteByte(')')

View File

@@ -35,8 +35,6 @@ const (
FieldStatus = "status"
// FieldUsername holds the string denoting the username field in the database.
FieldUsername = "username"
// FieldWechat holds the string denoting the wechat field in the database.
FieldWechat = "wechat"
// FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
@@ -51,6 +49,8 @@ const (
EdgeAllowedGroups = "allowed_groups"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs"
// EdgeAttributeValues holds the string denoting the attribute_values edge name in mutations.
EdgeAttributeValues = "attribute_values"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@@ -95,6 +95,13 @@ const (
UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "user_id"
// AttributeValuesTable is the table that holds the attribute_values relation/edge.
AttributeValuesTable = "user_attribute_values"
// AttributeValuesInverseTable is the table name for the UserAttributeValue entity.
// It exists in this package in order to avoid circular dependency with the "userattributevalue" package.
AttributeValuesInverseTable = "user_attribute_values"
// AttributeValuesColumn is the table column denoting the attribute_values relation/edge.
AttributeValuesColumn = "user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@@ -117,7 +124,6 @@ var Columns = []string{
FieldConcurrency,
FieldStatus,
FieldUsername,
FieldWechat,
FieldNotes,
}
@@ -171,10 +177,6 @@ var (
DefaultUsername string
// UsernameValidator is a validator for the "username" field. It is called by the builders before save.
UsernameValidator func(string) error
// DefaultWechat holds the default value on creation for the "wechat" field.
DefaultWechat string
// WechatValidator is a validator for the "wechat" field. It is called by the builders before save.
WechatValidator func(string) error
// DefaultNotes holds the default value on creation for the "notes" field.
DefaultNotes string
)
@@ -237,11 +239,6 @@ func ByUsername(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUsername, opts...).ToFunc()
}
// ByWechat orders the results by the wechat field.
func ByWechat(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWechat, opts...).ToFunc()
}
// ByNotes orders the results by the notes field.
func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc()
@@ -331,6 +328,20 @@ func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
// ByAttributeValuesCount orders the results by attribute_values count.
func ByAttributeValuesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newAttributeValuesStep(), opts...)
}
}
// ByAttributeValues orders the results by attribute_values terms.
func ByAttributeValues(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAttributeValuesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -386,6 +397,13 @@ func newUsageLogsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
}
func newAttributeValuesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AttributeValuesInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AttributeValuesTable, AttributeValuesColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),

View File

@@ -105,11 +105,6 @@ func Username(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldUsername, v))
}
// Wechat applies equality check predicate on the "wechat" field. It's identical to WechatEQ.
func Wechat(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldWechat, v))
}
// Notes applies equality check predicate on the "notes" field. It's identical to NotesEQ.
func Notes(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v))
@@ -650,71 +645,6 @@ func UsernameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldUsername, v))
}
// WechatEQ applies the EQ predicate on the "wechat" field.
func WechatEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldWechat, v))
}
// WechatNEQ applies the NEQ predicate on the "wechat" field.
func WechatNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldWechat, v))
}
// WechatIn applies the In predicate on the "wechat" field.
func WechatIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldWechat, vs...))
}
// WechatNotIn applies the NotIn predicate on the "wechat" field.
func WechatNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldWechat, vs...))
}
// WechatGT applies the GT predicate on the "wechat" field.
func WechatGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldWechat, v))
}
// WechatGTE applies the GTE predicate on the "wechat" field.
func WechatGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldWechat, v))
}
// WechatLT applies the LT predicate on the "wechat" field.
func WechatLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldWechat, v))
}
// WechatLTE applies the LTE predicate on the "wechat" field.
func WechatLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldWechat, v))
}
// WechatContains applies the Contains predicate on the "wechat" field.
func WechatContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldWechat, v))
}
// WechatHasPrefix applies the HasPrefix predicate on the "wechat" field.
func WechatHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldWechat, v))
}
// WechatHasSuffix applies the HasSuffix predicate on the "wechat" field.
func WechatHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldWechat, v))
}
// WechatEqualFold applies the EqualFold predicate on the "wechat" field.
func WechatEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldWechat, v))
}
// WechatContainsFold applies the ContainsFold predicate on the "wechat" field.
func WechatContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldWechat, v))
}
// NotesEQ applies the EQ predicate on the "notes" field.
func NotesEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v))
@@ -918,6 +848,29 @@ func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.User {
})
}
// HasAttributeValues applies the HasEdge predicate on the "attribute_values" edge.
func HasAttributeValues() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AttributeValuesTable, AttributeValuesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAttributeValuesWith applies the HasEdge predicate on the "attribute_values" edge with a given conditions (other predicates).
func HasAttributeValuesWith(preds ...predicate.UserAttributeValue) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newAttributeValuesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -151,20 +152,6 @@ func (_c *UserCreate) SetNillableUsername(v *string) *UserCreate {
return _c
}
// SetWechat sets the "wechat" field.
func (_c *UserCreate) SetWechat(v string) *UserCreate {
_c.mutation.SetWechat(v)
return _c
}
// SetNillableWechat sets the "wechat" field if the given value is not nil.
func (_c *UserCreate) SetNillableWechat(v *string) *UserCreate {
if v != nil {
_c.SetWechat(*v)
}
return _c
}
// SetNotes sets the "notes" field.
func (_c *UserCreate) SetNotes(v string) *UserCreate {
_c.mutation.SetNotes(v)
@@ -269,6 +256,21 @@ func (_c *UserCreate) AddUsageLogs(v ...*UsageLog) *UserCreate {
return _c.AddUsageLogIDs(ids...)
}
// AddAttributeValueIDs adds the "attribute_values" edge to the UserAttributeValue entity by IDs.
func (_c *UserCreate) AddAttributeValueIDs(ids ...int64) *UserCreate {
_c.mutation.AddAttributeValueIDs(ids...)
return _c
}
// AddAttributeValues adds the "attribute_values" edges to the UserAttributeValue entity.
func (_c *UserCreate) AddAttributeValues(v ...*UserAttributeValue) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddAttributeValueIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@@ -340,10 +342,6 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultUsername
_c.mutation.SetUsername(v)
}
if _, ok := _c.mutation.Wechat(); !ok {
v := user.DefaultWechat
_c.mutation.SetWechat(v)
}
if _, ok := _c.mutation.Notes(); !ok {
v := user.DefaultNotes
_c.mutation.SetNotes(v)
@@ -405,14 +403,6 @@ func (_c *UserCreate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if _, ok := _c.mutation.Wechat(); !ok {
return &ValidationError{Name: "wechat", err: errors.New(`ent: missing required field "User.wechat"`)}
}
if v, ok := _c.mutation.Wechat(); ok {
if err := user.WechatValidator(v); err != nil {
return &ValidationError{Name: "wechat", err: fmt.Errorf(`ent: validator failed for field "User.wechat": %w`, err)}
}
}
if _, ok := _c.mutation.Notes(); !ok {
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
}
@@ -483,10 +473,6 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldUsername, field.TypeString, value)
_node.Username = value
}
if value, ok := _c.mutation.Wechat(); ok {
_spec.SetField(user.FieldWechat, field.TypeString, value)
_node.Wechat = value
}
if value, ok := _c.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
_node.Notes = value
@@ -591,6 +577,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.AttributeValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
@@ -769,18 +771,6 @@ func (u *UserUpsert) UpdateUsername() *UserUpsert {
return u
}
// SetWechat sets the "wechat" field.
func (u *UserUpsert) SetWechat(v string) *UserUpsert {
u.Set(user.FieldWechat, v)
return u
}
// UpdateWechat sets the "wechat" field to the value that was provided on create.
func (u *UserUpsert) UpdateWechat() *UserUpsert {
u.SetExcluded(user.FieldWechat)
return u
}
// SetNotes sets the "notes" field.
func (u *UserUpsert) SetNotes(v string) *UserUpsert {
u.Set(user.FieldNotes, v)
@@ -985,20 +975,6 @@ func (u *UserUpsertOne) UpdateUsername() *UserUpsertOne {
})
}
// SetWechat sets the "wechat" field.
func (u *UserUpsertOne) SetWechat(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetWechat(v)
})
}
// UpdateWechat sets the "wechat" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateWechat() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateWechat()
})
}
// SetNotes sets the "notes" field.
func (u *UserUpsertOne) SetNotes(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
@@ -1371,20 +1347,6 @@ func (u *UserUpsertBulk) UpdateUsername() *UserUpsertBulk {
})
}
// SetWechat sets the "wechat" field.
func (u *UserUpsertBulk) SetWechat(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetWechat(v)
})
}
// UpdateWechat sets the "wechat" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateWechat() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateWechat()
})
}
// SetNotes sets the "notes" field.
func (u *UserUpsertBulk) SetNotes(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {

View File

@@ -19,6 +19,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -35,6 +36,7 @@ type UserQuery struct {
withAssignedSubscriptions *UserSubscriptionQuery
withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery
withUserAllowedGroups *UserAllowedGroupQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
@@ -204,6 +206,28 @@ func (_q *UserQuery) QueryUsageLogs() *UsageLogQuery {
return query
}
// QueryAttributeValues chains the current query on the "attribute_values" edge.
func (_q *UserQuery) QueryAttributeValues() *UserAttributeValueQuery {
query := (&UserAttributeValueClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AttributeValuesTable, user.AttributeValuesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@@ -424,6 +448,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(),
withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@@ -497,6 +522,17 @@ func (_q *UserQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserQuery {
return _q
}
// WithAttributeValues tells the query-builder to eager-load the nodes that are connected to
// the "attribute_values" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAttributeValues(opts ...func(*UserAttributeValueQuery)) *UserQuery {
query := (&UserAttributeValueClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAttributeValues = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@@ -586,13 +622,14 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [7]bool{
loadedTypes = [8]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
_q.withAssignedSubscriptions != nil,
_q.withAllowedGroups != nil,
_q.withUsageLogs != nil,
_q.withAttributeValues != nil,
_q.withUserAllowedGroups != nil,
}
)
@@ -658,6 +695,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withAttributeValues; query != nil {
if err := _q.loadAttributeValues(ctx, query, nodes,
func(n *User) { n.Edges.AttributeValues = []*UserAttributeValue{} },
func(n *User, e *UserAttributeValue) { n.Edges.AttributeValues = append(n.Edges.AttributeValues, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@@ -885,6 +929,36 @@ func (_q *UserQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, no
}
return nil
}
func (_q *UserQuery) loadAttributeValues(ctx context.Context, query *UserAttributeValueQuery, nodes []*User, init func(*User), assign func(*User, *UserAttributeValue)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(userattributevalue.FieldUserID)
}
query.Where(predicate.UserAttributeValue(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.AttributeValuesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)

View File

@@ -17,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -171,20 +172,6 @@ func (_u *UserUpdate) SetNillableUsername(v *string) *UserUpdate {
return _u
}
// SetWechat sets the "wechat" field.
func (_u *UserUpdate) SetWechat(v string) *UserUpdate {
_u.mutation.SetWechat(v)
return _u
}
// SetNillableWechat sets the "wechat" field if the given value is not nil.
func (_u *UserUpdate) SetNillableWechat(v *string) *UserUpdate {
if v != nil {
_u.SetWechat(*v)
}
return _u
}
// SetNotes sets the "notes" field.
func (_u *UserUpdate) SetNotes(v string) *UserUpdate {
_u.mutation.SetNotes(v)
@@ -289,6 +276,21 @@ func (_u *UserUpdate) AddUsageLogs(v ...*UsageLog) *UserUpdate {
return _u.AddUsageLogIDs(ids...)
}
// AddAttributeValueIDs adds the "attribute_values" edge to the UserAttributeValue entity by IDs.
func (_u *UserUpdate) AddAttributeValueIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAttributeValueIDs(ids...)
return _u
}
// AddAttributeValues adds the "attribute_values" edges to the UserAttributeValue entity.
func (_u *UserUpdate) AddAttributeValues(v ...*UserAttributeValue) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAttributeValueIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@@ -420,6 +422,27 @@ func (_u *UserUpdate) RemoveUsageLogs(v ...*UsageLog) *UserUpdate {
return _u.RemoveUsageLogIDs(ids...)
}
// ClearAttributeValues clears all "attribute_values" edges to the UserAttributeValue entity.
func (_u *UserUpdate) ClearAttributeValues() *UserUpdate {
_u.mutation.ClearAttributeValues()
return _u
}
// RemoveAttributeValueIDs removes the "attribute_values" edge to UserAttributeValue entities by IDs.
func (_u *UserUpdate) RemoveAttributeValueIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveAttributeValueIDs(ids...)
return _u
}
// RemoveAttributeValues removes "attribute_values" edges to UserAttributeValue entities.
func (_u *UserUpdate) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAttributeValueIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -489,11 +512,6 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := _u.mutation.Wechat(); ok {
if err := user.WechatValidator(v); err != nil {
return &ValidationError{Name: "wechat", err: fmt.Errorf(`ent: validator failed for field "User.wechat": %w`, err)}
}
}
return nil
}
@@ -545,9 +563,6 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := _u.mutation.Wechat(); ok {
_spec.SetField(user.FieldWechat, field.TypeString, value)
}
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
@@ -833,6 +848,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AttributeValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAttributeValuesIDs(); len(nodes) > 0 && !_u.mutation.AttributeValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AttributeValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -991,20 +1051,6 @@ func (_u *UserUpdateOne) SetNillableUsername(v *string) *UserUpdateOne {
return _u
}
// SetWechat sets the "wechat" field.
func (_u *UserUpdateOne) SetWechat(v string) *UserUpdateOne {
_u.mutation.SetWechat(v)
return _u
}
// SetNillableWechat sets the "wechat" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableWechat(v *string) *UserUpdateOne {
if v != nil {
_u.SetWechat(*v)
}
return _u
}
// SetNotes sets the "notes" field.
func (_u *UserUpdateOne) SetNotes(v string) *UserUpdateOne {
_u.mutation.SetNotes(v)
@@ -1109,6 +1155,21 @@ func (_u *UserUpdateOne) AddUsageLogs(v ...*UsageLog) *UserUpdateOne {
return _u.AddUsageLogIDs(ids...)
}
// AddAttributeValueIDs adds the "attribute_values" edge to the UserAttributeValue entity by IDs.
func (_u *UserUpdateOne) AddAttributeValueIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAttributeValueIDs(ids...)
return _u
}
// AddAttributeValues adds the "attribute_values" edges to the UserAttributeValue entity.
func (_u *UserUpdateOne) AddAttributeValues(v ...*UserAttributeValue) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAttributeValueIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@@ -1240,6 +1301,27 @@ func (_u *UserUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserUpdateOne {
return _u.RemoveUsageLogIDs(ids...)
}
// ClearAttributeValues clears all "attribute_values" edges to the UserAttributeValue entity.
func (_u *UserUpdateOne) ClearAttributeValues() *UserUpdateOne {
_u.mutation.ClearAttributeValues()
return _u
}
// RemoveAttributeValueIDs removes the "attribute_values" edge to UserAttributeValue entities by IDs.
func (_u *UserUpdateOne) RemoveAttributeValueIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveAttributeValueIDs(ids...)
return _u
}
// RemoveAttributeValues removes "attribute_values" edges to UserAttributeValue entities.
func (_u *UserUpdateOne) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAttributeValueIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@@ -1322,11 +1404,6 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := _u.mutation.Wechat(); ok {
if err := user.WechatValidator(v); err != nil {
return &ValidationError{Name: "wechat", err: fmt.Errorf(`ent: validator failed for field "User.wechat": %w`, err)}
}
}
return nil
}
@@ -1395,9 +1472,6 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := _u.mutation.Wechat(); ok {
_spec.SetField(user.FieldWechat, field.TypeString, value)
}
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
@@ -1683,6 +1757,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AttributeValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAttributeValuesIDs(); len(nodes) > 0 && !_u.mutation.AttributeValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AttributeValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@@ -0,0 +1,276 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
)
// UserAttributeDefinition is the model entity for the UserAttributeDefinition schema.
type UserAttributeDefinition struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// DeletedAt holds the value of the "deleted_at" field.
DeletedAt *time.Time `json:"deleted_at,omitempty"`
// Key holds the value of the "key" field.
Key string `json:"key,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// Description holds the value of the "description" field.
Description string `json:"description,omitempty"`
// Type holds the value of the "type" field.
Type string `json:"type,omitempty"`
// Options holds the value of the "options" field.
Options []map[string]interface{} `json:"options,omitempty"`
// Required holds the value of the "required" field.
Required bool `json:"required,omitempty"`
// Validation holds the value of the "validation" field.
Validation map[string]interface{} `json:"validation,omitempty"`
// Placeholder holds the value of the "placeholder" field.
Placeholder string `json:"placeholder,omitempty"`
// DisplayOrder holds the value of the "display_order" field.
DisplayOrder int `json:"display_order,omitempty"`
// Enabled holds the value of the "enabled" field.
Enabled bool `json:"enabled,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserAttributeDefinitionQuery when eager-loading is set.
Edges UserAttributeDefinitionEdges `json:"edges"`
selectValues sql.SelectValues
}
// UserAttributeDefinitionEdges holds the relations/edges for other nodes in the graph.
type UserAttributeDefinitionEdges struct {
// Values holds the value of the values edge.
Values []*UserAttributeValue `json:"values,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [1]bool
}
// ValuesOrErr returns the Values value or an error if the edge
// was not loaded in eager-loading.
func (e UserAttributeDefinitionEdges) ValuesOrErr() ([]*UserAttributeValue, error) {
if e.loadedTypes[0] {
return e.Values, nil
}
return nil, &NotLoadedError{edge: "values"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UserAttributeDefinition) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case userattributedefinition.FieldOptions, userattributedefinition.FieldValidation:
values[i] = new([]byte)
case userattributedefinition.FieldRequired, userattributedefinition.FieldEnabled:
values[i] = new(sql.NullBool)
case userattributedefinition.FieldID, userattributedefinition.FieldDisplayOrder:
values[i] = new(sql.NullInt64)
case userattributedefinition.FieldKey, userattributedefinition.FieldName, userattributedefinition.FieldDescription, userattributedefinition.FieldType, userattributedefinition.FieldPlaceholder:
values[i] = new(sql.NullString)
case userattributedefinition.FieldCreatedAt, userattributedefinition.FieldUpdatedAt, userattributedefinition.FieldDeletedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the UserAttributeDefinition fields.
func (_m *UserAttributeDefinition) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case userattributedefinition.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case userattributedefinition.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case userattributedefinition.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case userattributedefinition.FieldDeletedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
} else if value.Valid {
_m.DeletedAt = new(time.Time)
*_m.DeletedAt = value.Time
}
case userattributedefinition.FieldKey:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field key", values[i])
} else if value.Valid {
_m.Key = value.String
}
case userattributedefinition.FieldName:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field name", values[i])
} else if value.Valid {
_m.Name = value.String
}
case userattributedefinition.FieldDescription:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field description", values[i])
} else if value.Valid {
_m.Description = value.String
}
case userattributedefinition.FieldType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field type", values[i])
} else if value.Valid {
_m.Type = value.String
}
case userattributedefinition.FieldOptions:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field options", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Options); err != nil {
return fmt.Errorf("unmarshal field options: %w", err)
}
}
case userattributedefinition.FieldRequired:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field required", values[i])
} else if value.Valid {
_m.Required = value.Bool
}
case userattributedefinition.FieldValidation:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field validation", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Validation); err != nil {
return fmt.Errorf("unmarshal field validation: %w", err)
}
}
case userattributedefinition.FieldPlaceholder:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field placeholder", values[i])
} else if value.Valid {
_m.Placeholder = value.String
}
case userattributedefinition.FieldDisplayOrder:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field display_order", values[i])
} else if value.Valid {
_m.DisplayOrder = int(value.Int64)
}
case userattributedefinition.FieldEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field enabled", values[i])
} else if value.Valid {
_m.Enabled = value.Bool
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the UserAttributeDefinition.
// This includes values selected through modifiers, order, etc.
func (_m *UserAttributeDefinition) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryValues queries the "values" edge of the UserAttributeDefinition entity.
func (_m *UserAttributeDefinition) QueryValues() *UserAttributeValueQuery {
return NewUserAttributeDefinitionClient(_m.config).QueryValues(_m)
}
// Update returns a builder for updating this UserAttributeDefinition.
// Note that you need to call UserAttributeDefinition.Unwrap() before calling this method if this UserAttributeDefinition
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *UserAttributeDefinition) Update() *UserAttributeDefinitionUpdateOne {
return NewUserAttributeDefinitionClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the UserAttributeDefinition entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *UserAttributeDefinition) Unwrap() *UserAttributeDefinition {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: UserAttributeDefinition is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *UserAttributeDefinition) String() string {
var builder strings.Builder
builder.WriteString("UserAttributeDefinition(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
if v := _m.DeletedAt; v != nil {
builder.WriteString("deleted_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("key=")
builder.WriteString(_m.Key)
builder.WriteString(", ")
builder.WriteString("name=")
builder.WriteString(_m.Name)
builder.WriteString(", ")
builder.WriteString("description=")
builder.WriteString(_m.Description)
builder.WriteString(", ")
builder.WriteString("type=")
builder.WriteString(_m.Type)
builder.WriteString(", ")
builder.WriteString("options=")
builder.WriteString(fmt.Sprintf("%v", _m.Options))
builder.WriteString(", ")
builder.WriteString("required=")
builder.WriteString(fmt.Sprintf("%v", _m.Required))
builder.WriteString(", ")
builder.WriteString("validation=")
builder.WriteString(fmt.Sprintf("%v", _m.Validation))
builder.WriteString(", ")
builder.WriteString("placeholder=")
builder.WriteString(_m.Placeholder)
builder.WriteString(", ")
builder.WriteString("display_order=")
builder.WriteString(fmt.Sprintf("%v", _m.DisplayOrder))
builder.WriteString(", ")
builder.WriteString("enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.Enabled))
builder.WriteByte(')')
return builder.String()
}
// UserAttributeDefinitions is a parsable slice of UserAttributeDefinition.
type UserAttributeDefinitions []*UserAttributeDefinition

View File

@@ -0,0 +1,205 @@
// Code generated by ent, DO NOT EDIT.
package userattributedefinition
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the userattributedefinition type in the database.
Label = "user_attribute_definition"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
FieldDeletedAt = "deleted_at"
// FieldKey holds the string denoting the key field in the database.
FieldKey = "key"
// FieldName holds the string denoting the name field in the database.
FieldName = "name"
// FieldDescription holds the string denoting the description field in the database.
FieldDescription = "description"
// FieldType holds the string denoting the type field in the database.
FieldType = "type"
// FieldOptions holds the string denoting the options field in the database.
FieldOptions = "options"
// FieldRequired holds the string denoting the required field in the database.
FieldRequired = "required"
// FieldValidation holds the string denoting the validation field in the database.
FieldValidation = "validation"
// FieldPlaceholder holds the string denoting the placeholder field in the database.
FieldPlaceholder = "placeholder"
// FieldDisplayOrder holds the string denoting the display_order field in the database.
FieldDisplayOrder = "display_order"
// FieldEnabled holds the string denoting the enabled field in the database.
FieldEnabled = "enabled"
// EdgeValues holds the string denoting the values edge name in mutations.
EdgeValues = "values"
// Table holds the table name of the userattributedefinition in the database.
Table = "user_attribute_definitions"
// ValuesTable is the table that holds the values relation/edge.
ValuesTable = "user_attribute_values"
// ValuesInverseTable is the table name for the UserAttributeValue entity.
// It exists in this package in order to avoid circular dependency with the "userattributevalue" package.
ValuesInverseTable = "user_attribute_values"
// ValuesColumn is the table column denoting the values relation/edge.
ValuesColumn = "attribute_id"
)
// Columns holds all SQL columns for userattributedefinition fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldDeletedAt,
FieldKey,
FieldName,
FieldDescription,
FieldType,
FieldOptions,
FieldRequired,
FieldValidation,
FieldPlaceholder,
FieldDisplayOrder,
FieldEnabled,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
// Note that the variables below are initialized by the runtime
// package on the initialization of the application. Therefore,
// it should be imported in the main as follows:
//
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var (
Hooks [1]ent.Hook
Interceptors [1]ent.Interceptor
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// KeyValidator is a validator for the "key" field. It is called by the builders before save.
KeyValidator func(string) error
// NameValidator is a validator for the "name" field. It is called by the builders before save.
NameValidator func(string) error
// DefaultDescription holds the default value on creation for the "description" field.
DefaultDescription string
// TypeValidator is a validator for the "type" field. It is called by the builders before save.
TypeValidator func(string) error
// DefaultOptions holds the default value on creation for the "options" field.
DefaultOptions []map[string]interface{}
// DefaultRequired holds the default value on creation for the "required" field.
DefaultRequired bool
// DefaultValidation holds the default value on creation for the "validation" field.
DefaultValidation map[string]interface{}
// DefaultPlaceholder holds the default value on creation for the "placeholder" field.
DefaultPlaceholder string
// PlaceholderValidator is a validator for the "placeholder" field. It is called by the builders before save.
PlaceholderValidator func(string) error
// DefaultDisplayOrder holds the default value on creation for the "display_order" field.
DefaultDisplayOrder int
// DefaultEnabled holds the default value on creation for the "enabled" field.
DefaultEnabled bool
)
// OrderOption defines the ordering options for the UserAttributeDefinition queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByDeletedAt orders the results by the deleted_at field.
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
}
// ByKey orders the results by the key field.
func ByKey(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldKey, opts...).ToFunc()
}
// ByName orders the results by the name field.
func ByName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldName, opts...).ToFunc()
}
// ByDescription orders the results by the description field.
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDescription, opts...).ToFunc()
}
// ByType orders the results by the type field.
func ByType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldType, opts...).ToFunc()
}
// ByRequired orders the results by the required field.
func ByRequired(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRequired, opts...).ToFunc()
}
// ByPlaceholder orders the results by the placeholder field.
func ByPlaceholder(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPlaceholder, opts...).ToFunc()
}
// ByDisplayOrder orders the results by the display_order field.
func ByDisplayOrder(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDisplayOrder, opts...).ToFunc()
}
// ByEnabled orders the results by the enabled field.
func ByEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEnabled, opts...).ToFunc()
}
// ByValuesCount orders the results by values count.
func ByValuesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newValuesStep(), opts...)
}
}
// ByValues orders the results by values terms.
func ByValues(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newValuesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newValuesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(ValuesInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ValuesTable, ValuesColumn),
)
}

View File

@@ -0,0 +1,664 @@
// Code generated by ent, DO NOT EDIT.
package userattributedefinition
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldUpdatedAt, v))
}
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
func DeletedAt(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDeletedAt, v))
}
// Key applies equality check predicate on the "key" field. It's identical to KeyEQ.
func Key(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldKey, v))
}
// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
func Name(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldName, v))
}
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
func Description(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDescription, v))
}
// Type applies equality check predicate on the "type" field. It's identical to TypeEQ.
func Type(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldType, v))
}
// Required applies equality check predicate on the "required" field. It's identical to RequiredEQ.
func Required(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldRequired, v))
}
// Placeholder applies equality check predicate on the "placeholder" field. It's identical to PlaceholderEQ.
func Placeholder(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldPlaceholder, v))
}
// DisplayOrder applies equality check predicate on the "display_order" field. It's identical to DisplayOrderEQ.
func DisplayOrder(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDisplayOrder, v))
}
// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ.
func Enabled(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldEnabled, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldUpdatedAt, v))
}
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
func DeletedAtEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDeletedAt, v))
}
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
func DeletedAtNEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldDeletedAt, v))
}
// DeletedAtIn applies the In predicate on the "deleted_at" field.
func DeletedAtIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldDeletedAt, vs...))
}
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
func DeletedAtNotIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldDeletedAt, vs...))
}
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
func DeletedAtGT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldDeletedAt, v))
}
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
func DeletedAtGTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldDeletedAt, v))
}
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
func DeletedAtLT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldDeletedAt, v))
}
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
func DeletedAtLTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldDeletedAt, v))
}
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
func DeletedAtIsNil() predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIsNull(FieldDeletedAt))
}
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
func DeletedAtNotNil() predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotNull(FieldDeletedAt))
}
// KeyEQ applies the EQ predicate on the "key" field.
func KeyEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldKey, v))
}
// KeyNEQ applies the NEQ predicate on the "key" field.
func KeyNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldKey, v))
}
// KeyIn applies the In predicate on the "key" field.
func KeyIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldKey, vs...))
}
// KeyNotIn applies the NotIn predicate on the "key" field.
func KeyNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldKey, vs...))
}
// KeyGT applies the GT predicate on the "key" field.
func KeyGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldKey, v))
}
// KeyGTE applies the GTE predicate on the "key" field.
func KeyGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldKey, v))
}
// KeyLT applies the LT predicate on the "key" field.
func KeyLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldKey, v))
}
// KeyLTE applies the LTE predicate on the "key" field.
func KeyLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldKey, v))
}
// KeyContains applies the Contains predicate on the "key" field.
func KeyContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldKey, v))
}
// KeyHasPrefix applies the HasPrefix predicate on the "key" field.
func KeyHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldKey, v))
}
// KeyHasSuffix applies the HasSuffix predicate on the "key" field.
func KeyHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldKey, v))
}
// KeyEqualFold applies the EqualFold predicate on the "key" field.
func KeyEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldKey, v))
}
// KeyContainsFold applies the ContainsFold predicate on the "key" field.
func KeyContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldKey, v))
}
// NameEQ applies the EQ predicate on the "name" field.
func NameEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldName, v))
}
// NameNEQ applies the NEQ predicate on the "name" field.
func NameNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldName, v))
}
// NameIn applies the In predicate on the "name" field.
func NameIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldName, vs...))
}
// NameNotIn applies the NotIn predicate on the "name" field.
func NameNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldName, vs...))
}
// NameGT applies the GT predicate on the "name" field.
func NameGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldName, v))
}
// NameGTE applies the GTE predicate on the "name" field.
func NameGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldName, v))
}
// NameLT applies the LT predicate on the "name" field.
func NameLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldName, v))
}
// NameLTE applies the LTE predicate on the "name" field.
func NameLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldName, v))
}
// NameContains applies the Contains predicate on the "name" field.
func NameContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldName, v))
}
// NameHasPrefix applies the HasPrefix predicate on the "name" field.
func NameHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldName, v))
}
// NameHasSuffix applies the HasSuffix predicate on the "name" field.
func NameHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldName, v))
}
// NameEqualFold applies the EqualFold predicate on the "name" field.
func NameEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldName, v))
}
// NameContainsFold applies the ContainsFold predicate on the "name" field.
func NameContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldName, v))
}
// DescriptionEQ applies the EQ predicate on the "description" field.
func DescriptionEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDescription, v))
}
// DescriptionNEQ applies the NEQ predicate on the "description" field.
func DescriptionNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldDescription, v))
}
// DescriptionIn applies the In predicate on the "description" field.
func DescriptionIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldDescription, vs...))
}
// DescriptionNotIn applies the NotIn predicate on the "description" field.
func DescriptionNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldDescription, vs...))
}
// DescriptionGT applies the GT predicate on the "description" field.
func DescriptionGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldDescription, v))
}
// DescriptionGTE applies the GTE predicate on the "description" field.
func DescriptionGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldDescription, v))
}
// DescriptionLT applies the LT predicate on the "description" field.
func DescriptionLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldDescription, v))
}
// DescriptionLTE applies the LTE predicate on the "description" field.
func DescriptionLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldDescription, v))
}
// DescriptionContains applies the Contains predicate on the "description" field.
func DescriptionContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldDescription, v))
}
// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
func DescriptionHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldDescription, v))
}
// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
func DescriptionHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldDescription, v))
}
// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
func DescriptionEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldDescription, v))
}
// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
func DescriptionContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldDescription, v))
}
// TypeEQ applies the EQ predicate on the "type" field.
func TypeEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldType, v))
}
// TypeNEQ applies the NEQ predicate on the "type" field.
func TypeNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldType, v))
}
// TypeIn applies the In predicate on the "type" field.
func TypeIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldType, vs...))
}
// TypeNotIn applies the NotIn predicate on the "type" field.
func TypeNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldType, vs...))
}
// TypeGT applies the GT predicate on the "type" field.
func TypeGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldType, v))
}
// TypeGTE applies the GTE predicate on the "type" field.
func TypeGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldType, v))
}
// TypeLT applies the LT predicate on the "type" field.
func TypeLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldType, v))
}
// TypeLTE applies the LTE predicate on the "type" field.
func TypeLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldType, v))
}
// TypeContains applies the Contains predicate on the "type" field.
func TypeContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldType, v))
}
// TypeHasPrefix applies the HasPrefix predicate on the "type" field.
func TypeHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldType, v))
}
// TypeHasSuffix applies the HasSuffix predicate on the "type" field.
func TypeHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldType, v))
}
// TypeEqualFold applies the EqualFold predicate on the "type" field.
func TypeEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldType, v))
}
// TypeContainsFold applies the ContainsFold predicate on the "type" field.
func TypeContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldType, v))
}
// RequiredEQ applies the EQ predicate on the "required" field.
func RequiredEQ(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldRequired, v))
}
// RequiredNEQ applies the NEQ predicate on the "required" field.
func RequiredNEQ(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldRequired, v))
}
// PlaceholderEQ applies the EQ predicate on the "placeholder" field.
func PlaceholderEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldPlaceholder, v))
}
// PlaceholderNEQ applies the NEQ predicate on the "placeholder" field.
func PlaceholderNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldPlaceholder, v))
}
// PlaceholderIn applies the In predicate on the "placeholder" field.
func PlaceholderIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldPlaceholder, vs...))
}
// PlaceholderNotIn applies the NotIn predicate on the "placeholder" field.
func PlaceholderNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldPlaceholder, vs...))
}
// PlaceholderGT applies the GT predicate on the "placeholder" field.
func PlaceholderGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldPlaceholder, v))
}
// PlaceholderGTE applies the GTE predicate on the "placeholder" field.
func PlaceholderGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldPlaceholder, v))
}
// PlaceholderLT applies the LT predicate on the "placeholder" field.
func PlaceholderLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldPlaceholder, v))
}
// PlaceholderLTE applies the LTE predicate on the "placeholder" field.
func PlaceholderLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldPlaceholder, v))
}
// PlaceholderContains applies the Contains predicate on the "placeholder" field.
func PlaceholderContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldPlaceholder, v))
}
// PlaceholderHasPrefix applies the HasPrefix predicate on the "placeholder" field.
func PlaceholderHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldPlaceholder, v))
}
// PlaceholderHasSuffix applies the HasSuffix predicate on the "placeholder" field.
func PlaceholderHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldPlaceholder, v))
}
// PlaceholderEqualFold applies the EqualFold predicate on the "placeholder" field.
func PlaceholderEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldPlaceholder, v))
}
// PlaceholderContainsFold applies the ContainsFold predicate on the "placeholder" field.
func PlaceholderContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldPlaceholder, v))
}
// DisplayOrderEQ applies the EQ predicate on the "display_order" field.
func DisplayOrderEQ(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDisplayOrder, v))
}
// DisplayOrderNEQ applies the NEQ predicate on the "display_order" field.
func DisplayOrderNEQ(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldDisplayOrder, v))
}
// DisplayOrderIn applies the In predicate on the "display_order" field.
func DisplayOrderIn(vs ...int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldDisplayOrder, vs...))
}
// DisplayOrderNotIn applies the NotIn predicate on the "display_order" field.
func DisplayOrderNotIn(vs ...int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldDisplayOrder, vs...))
}
// DisplayOrderGT applies the GT predicate on the "display_order" field.
func DisplayOrderGT(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldDisplayOrder, v))
}
// DisplayOrderGTE applies the GTE predicate on the "display_order" field.
func DisplayOrderGTE(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldDisplayOrder, v))
}
// DisplayOrderLT applies the LT predicate on the "display_order" field.
func DisplayOrderLT(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldDisplayOrder, v))
}
// DisplayOrderLTE applies the LTE predicate on the "display_order" field.
func DisplayOrderLTE(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldDisplayOrder, v))
}
// EnabledEQ applies the EQ predicate on the "enabled" field.
func EnabledEQ(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldEnabled, v))
}
// EnabledNEQ applies the NEQ predicate on the "enabled" field.
func EnabledNEQ(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldEnabled, v))
}
// HasValues applies the HasEdge predicate on the "values" edge.
func HasValues() predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ValuesTable, ValuesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasValuesWith applies the HasEdge predicate on the "values" edge with a given conditions (other predicates).
func HasValuesWith(preds ...predicate.UserAttributeValue) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(func(s *sql.Selector) {
step := newValuesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserAttributeDefinition) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.UserAttributeDefinition) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.UserAttributeDefinition) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
)
// UserAttributeDefinitionDelete is the builder for deleting a UserAttributeDefinition entity.
type UserAttributeDefinitionDelete struct {
config
hooks []Hook
mutation *UserAttributeDefinitionMutation
}
// Where appends a list predicates to the UserAttributeDefinitionDelete builder.
func (_d *UserAttributeDefinitionDelete) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UserAttributeDefinitionDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserAttributeDefinitionDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UserAttributeDefinitionDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(userattributedefinition.Table, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UserAttributeDefinitionDeleteOne is the builder for deleting a single UserAttributeDefinition entity.
type UserAttributeDefinitionDeleteOne struct {
_d *UserAttributeDefinitionDelete
}
// Where appends a list predicates to the UserAttributeDefinitionDelete builder.
func (_d *UserAttributeDefinitionDeleteOne) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UserAttributeDefinitionDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{userattributedefinition.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserAttributeDefinitionDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,606 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeDefinitionQuery is the builder for querying UserAttributeDefinition entities.
type UserAttributeDefinitionQuery struct {
config
ctx *QueryContext
order []userattributedefinition.OrderOption
inters []Interceptor
predicates []predicate.UserAttributeDefinition
withValues *UserAttributeValueQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UserAttributeDefinitionQuery builder.
func (_q *UserAttributeDefinitionQuery) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UserAttributeDefinitionQuery) Limit(limit int) *UserAttributeDefinitionQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UserAttributeDefinitionQuery) Offset(offset int) *UserAttributeDefinitionQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UserAttributeDefinitionQuery) Unique(unique bool) *UserAttributeDefinitionQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UserAttributeDefinitionQuery) Order(o ...userattributedefinition.OrderOption) *UserAttributeDefinitionQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryValues chains the current query on the "values" edge.
func (_q *UserAttributeDefinitionQuery) QueryValues() *UserAttributeValueQuery {
query := (&UserAttributeValueClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(userattributedefinition.Table, userattributedefinition.FieldID, selector),
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, userattributedefinition.ValuesTable, userattributedefinition.ValuesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserAttributeDefinition entity from the query.
// Returns a *NotFoundError when no UserAttributeDefinition was found.
func (_q *UserAttributeDefinitionQuery) First(ctx context.Context) (*UserAttributeDefinition, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{userattributedefinition.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) FirstX(ctx context.Context) *UserAttributeDefinition {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UserAttributeDefinition ID from the query.
// Returns a *NotFoundError when no UserAttributeDefinition ID was found.
func (_q *UserAttributeDefinitionQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{userattributedefinition.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UserAttributeDefinition entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UserAttributeDefinition entity is found.
// Returns a *NotFoundError when no UserAttributeDefinition entities are found.
func (_q *UserAttributeDefinitionQuery) Only(ctx context.Context) (*UserAttributeDefinition, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{userattributedefinition.Label}
default:
return nil, &NotSingularError{userattributedefinition.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) OnlyX(ctx context.Context) *UserAttributeDefinition {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UserAttributeDefinition ID in the query.
// Returns a *NotSingularError when more than one UserAttributeDefinition ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UserAttributeDefinitionQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{userattributedefinition.Label}
default:
err = &NotSingularError{userattributedefinition.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UserAttributeDefinitions.
func (_q *UserAttributeDefinitionQuery) All(ctx context.Context) ([]*UserAttributeDefinition, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UserAttributeDefinition, *UserAttributeDefinitionQuery]()
return withInterceptors[[]*UserAttributeDefinition](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) AllX(ctx context.Context) []*UserAttributeDefinition {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UserAttributeDefinition IDs.
func (_q *UserAttributeDefinitionQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(userattributedefinition.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UserAttributeDefinitionQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UserAttributeDefinitionQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UserAttributeDefinitionQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UserAttributeDefinitionQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UserAttributeDefinitionQuery) Clone() *UserAttributeDefinitionQuery {
if _q == nil {
return nil
}
return &UserAttributeDefinitionQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]userattributedefinition.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UserAttributeDefinition{}, _q.predicates...),
withValues: _q.withValues.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithValues tells the query-builder to eager-load the nodes that are connected to
// the "values" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserAttributeDefinitionQuery) WithValues(opts ...func(*UserAttributeValueQuery)) *UserAttributeDefinitionQuery {
query := (&UserAttributeValueClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withValues = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UserAttributeDefinition.Query().
// GroupBy(userattributedefinition.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UserAttributeDefinitionQuery) GroupBy(field string, fields ...string) *UserAttributeDefinitionGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UserAttributeDefinitionGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = userattributedefinition.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UserAttributeDefinition.Query().
// Select(userattributedefinition.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UserAttributeDefinitionQuery) Select(fields ...string) *UserAttributeDefinitionSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UserAttributeDefinitionSelect{UserAttributeDefinitionQuery: _q}
sbuild.label = userattributedefinition.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UserAttributeDefinitionSelect configured with the given aggregations.
func (_q *UserAttributeDefinitionQuery) Aggregate(fns ...AggregateFunc) *UserAttributeDefinitionSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UserAttributeDefinitionQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !userattributedefinition.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UserAttributeDefinitionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserAttributeDefinition, error) {
var (
nodes = []*UserAttributeDefinition{}
_spec = _q.querySpec()
loadedTypes = [1]bool{
_q.withValues != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UserAttributeDefinition).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UserAttributeDefinition{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withValues; query != nil {
if err := _q.loadValues(ctx, query, nodes,
func(n *UserAttributeDefinition) { n.Edges.Values = []*UserAttributeValue{} },
func(n *UserAttributeDefinition, e *UserAttributeValue) { n.Edges.Values = append(n.Edges.Values, e) }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *UserAttributeDefinitionQuery) loadValues(ctx context.Context, query *UserAttributeValueQuery, nodes []*UserAttributeDefinition, init func(*UserAttributeDefinition), assign func(*UserAttributeDefinition, *UserAttributeValue)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*UserAttributeDefinition)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(userattributevalue.FieldAttributeID)
}
query.Where(predicate.UserAttributeValue(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(userattributedefinition.ValuesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.AttributeID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "attribute_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserAttributeDefinitionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UserAttributeDefinitionQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(userattributedefinition.Table, userattributedefinition.Columns, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userattributedefinition.FieldID)
for i := range fields {
if fields[i] != userattributedefinition.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(userattributedefinition.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = userattributedefinition.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// UserAttributeDefinitionGroupBy is the group-by builder for UserAttributeDefinition entities.
type UserAttributeDefinitionGroupBy struct {
selector
build *UserAttributeDefinitionQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UserAttributeDefinitionGroupBy) Aggregate(fns ...AggregateFunc) *UserAttributeDefinitionGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UserAttributeDefinitionGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserAttributeDefinitionQuery, *UserAttributeDefinitionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UserAttributeDefinitionGroupBy) sqlScan(ctx context.Context, root *UserAttributeDefinitionQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UserAttributeDefinitionSelect is the builder for selecting fields of UserAttributeDefinition entities.
type UserAttributeDefinitionSelect struct {
*UserAttributeDefinitionQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UserAttributeDefinitionSelect) Aggregate(fns ...AggregateFunc) *UserAttributeDefinitionSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UserAttributeDefinitionSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserAttributeDefinitionQuery, *UserAttributeDefinitionSelect](ctx, _s.UserAttributeDefinitionQuery, _s, _s.inters, v)
}
func (_s *UserAttributeDefinitionSelect) sqlScan(ctx context.Context, root *UserAttributeDefinitionQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,846 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeDefinitionUpdate is the builder for updating UserAttributeDefinition entities.
type UserAttributeDefinitionUpdate struct {
config
hooks []Hook
mutation *UserAttributeDefinitionMutation
}
// Where appends a list predicates to the UserAttributeDefinitionUpdate builder.
func (_u *UserAttributeDefinitionUpdate) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserAttributeDefinitionUpdate) SetUpdatedAt(v time.Time) *UserAttributeDefinitionUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserAttributeDefinitionUpdate) SetDeletedAt(v time.Time) *UserAttributeDefinitionUpdate {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableDeletedAt(v *time.Time) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserAttributeDefinitionUpdate) ClearDeletedAt() *UserAttributeDefinitionUpdate {
_u.mutation.ClearDeletedAt()
return _u
}
// SetKey sets the "key" field.
func (_u *UserAttributeDefinitionUpdate) SetKey(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableKey(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetName sets the "name" field.
func (_u *UserAttributeDefinitionUpdate) SetName(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetName(v)
return _u
}
// SetNillableName sets the "name" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableName(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetName(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *UserAttributeDefinitionUpdate) SetDescription(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetDescription(v)
return _u
}
// SetNillableDescription sets the "description" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableDescription(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetDescription(*v)
}
return _u
}
// SetType sets the "type" field.
func (_u *UserAttributeDefinitionUpdate) SetType(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetType(v)
return _u
}
// SetNillableType sets the "type" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableType(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetType(*v)
}
return _u
}
// SetOptions sets the "options" field.
func (_u *UserAttributeDefinitionUpdate) SetOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdate {
_u.mutation.SetOptions(v)
return _u
}
// AppendOptions appends value to the "options" field.
func (_u *UserAttributeDefinitionUpdate) AppendOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdate {
_u.mutation.AppendOptions(v)
return _u
}
// SetRequired sets the "required" field.
func (_u *UserAttributeDefinitionUpdate) SetRequired(v bool) *UserAttributeDefinitionUpdate {
_u.mutation.SetRequired(v)
return _u
}
// SetNillableRequired sets the "required" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableRequired(v *bool) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetRequired(*v)
}
return _u
}
// SetValidation sets the "validation" field.
func (_u *UserAttributeDefinitionUpdate) SetValidation(v map[string]interface{}) *UserAttributeDefinitionUpdate {
_u.mutation.SetValidation(v)
return _u
}
// SetPlaceholder sets the "placeholder" field.
func (_u *UserAttributeDefinitionUpdate) SetPlaceholder(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetPlaceholder(v)
return _u
}
// SetNillablePlaceholder sets the "placeholder" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillablePlaceholder(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetPlaceholder(*v)
}
return _u
}
// SetDisplayOrder sets the "display_order" field.
func (_u *UserAttributeDefinitionUpdate) SetDisplayOrder(v int) *UserAttributeDefinitionUpdate {
_u.mutation.ResetDisplayOrder()
_u.mutation.SetDisplayOrder(v)
return _u
}
// SetNillableDisplayOrder sets the "display_order" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableDisplayOrder(v *int) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetDisplayOrder(*v)
}
return _u
}
// AddDisplayOrder adds value to the "display_order" field.
func (_u *UserAttributeDefinitionUpdate) AddDisplayOrder(v int) *UserAttributeDefinitionUpdate {
_u.mutation.AddDisplayOrder(v)
return _u
}
// SetEnabled sets the "enabled" field.
func (_u *UserAttributeDefinitionUpdate) SetEnabled(v bool) *UserAttributeDefinitionUpdate {
_u.mutation.SetEnabled(v)
return _u
}
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableEnabled(v *bool) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetEnabled(*v)
}
return _u
}
// AddValueIDs adds the "values" edge to the UserAttributeValue entity by IDs.
func (_u *UserAttributeDefinitionUpdate) AddValueIDs(ids ...int64) *UserAttributeDefinitionUpdate {
_u.mutation.AddValueIDs(ids...)
return _u
}
// AddValues adds the "values" edges to the UserAttributeValue entity.
func (_u *UserAttributeDefinitionUpdate) AddValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddValueIDs(ids...)
}
// Mutation returns the UserAttributeDefinitionMutation object of the builder.
func (_u *UserAttributeDefinitionUpdate) Mutation() *UserAttributeDefinitionMutation {
return _u.mutation
}
// ClearValues clears all "values" edges to the UserAttributeValue entity.
func (_u *UserAttributeDefinitionUpdate) ClearValues() *UserAttributeDefinitionUpdate {
_u.mutation.ClearValues()
return _u
}
// RemoveValueIDs removes the "values" edge to UserAttributeValue entities by IDs.
func (_u *UserAttributeDefinitionUpdate) RemoveValueIDs(ids ...int64) *UserAttributeDefinitionUpdate {
_u.mutation.RemoveValueIDs(ids...)
return _u
}
// RemoveValues removes "values" edges to UserAttributeValue entities.
func (_u *UserAttributeDefinitionUpdate) RemoveValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveValueIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserAttributeDefinitionUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
return 0, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserAttributeDefinitionUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UserAttributeDefinitionUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserAttributeDefinitionUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserAttributeDefinitionUpdate) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if userattributedefinition.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized userattributedefinition.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := userattributedefinition.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserAttributeDefinitionUpdate) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := userattributedefinition.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.key": %w`, err)}
}
}
if v, ok := _u.mutation.Name(); ok {
if err := userattributedefinition.NameValidator(v); err != nil {
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.name": %w`, err)}
}
}
if v, ok := _u.mutation.GetType(); ok {
if err := userattributedefinition.TypeValidator(v); err != nil {
return &ValidationError{Name: "type", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.type": %w`, err)}
}
}
if v, ok := _u.mutation.Placeholder(); ok {
if err := userattributedefinition.PlaceholderValidator(v); err != nil {
return &ValidationError{Name: "placeholder", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.placeholder": %w`, err)}
}
}
return nil
}
func (_u *UserAttributeDefinitionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userattributedefinition.Table, userattributedefinition.Columns, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userattributedefinition.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(userattributedefinition.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(userattributedefinition.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(userattributedefinition.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Name(); ok {
_spec.SetField(userattributedefinition.FieldName, field.TypeString, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(userattributedefinition.FieldDescription, field.TypeString, value)
}
if value, ok := _u.mutation.GetType(); ok {
_spec.SetField(userattributedefinition.FieldType, field.TypeString, value)
}
if value, ok := _u.mutation.Options(); ok {
_spec.SetField(userattributedefinition.FieldOptions, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedOptions(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, userattributedefinition.FieldOptions, value)
})
}
if value, ok := _u.mutation.Required(); ok {
_spec.SetField(userattributedefinition.FieldRequired, field.TypeBool, value)
}
if value, ok := _u.mutation.Validation(); ok {
_spec.SetField(userattributedefinition.FieldValidation, field.TypeJSON, value)
}
if value, ok := _u.mutation.Placeholder(); ok {
_spec.SetField(userattributedefinition.FieldPlaceholder, field.TypeString, value)
}
if value, ok := _u.mutation.DisplayOrder(); ok {
_spec.SetField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedDisplayOrder(); ok {
_spec.AddField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.Enabled(); ok {
_spec.SetField(userattributedefinition.FieldEnabled, field.TypeBool, value)
}
if _u.mutation.ValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedValuesIDs(); len(nodes) > 0 && !_u.mutation.ValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userattributedefinition.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UserAttributeDefinitionUpdateOne is the builder for updating a single UserAttributeDefinition entity.
type UserAttributeDefinitionUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserAttributeDefinitionMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserAttributeDefinitionUpdateOne) SetUpdatedAt(v time.Time) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserAttributeDefinitionUpdateOne) SetDeletedAt(v time.Time) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableDeletedAt(v *time.Time) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserAttributeDefinitionUpdateOne) ClearDeletedAt() *UserAttributeDefinitionUpdateOne {
_u.mutation.ClearDeletedAt()
return _u
}
// SetKey sets the "key" field.
func (_u *UserAttributeDefinitionUpdateOne) SetKey(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableKey(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetName sets the "name" field.
func (_u *UserAttributeDefinitionUpdateOne) SetName(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetName(v)
return _u
}
// SetNillableName sets the "name" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableName(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetName(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *UserAttributeDefinitionUpdateOne) SetDescription(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetDescription(v)
return _u
}
// SetNillableDescription sets the "description" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableDescription(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetDescription(*v)
}
return _u
}
// SetType sets the "type" field.
func (_u *UserAttributeDefinitionUpdateOne) SetType(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetType(v)
return _u
}
// SetNillableType sets the "type" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableType(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetType(*v)
}
return _u
}
// SetOptions sets the "options" field.
func (_u *UserAttributeDefinitionUpdateOne) SetOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetOptions(v)
return _u
}
// AppendOptions appends value to the "options" field.
func (_u *UserAttributeDefinitionUpdateOne) AppendOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdateOne {
_u.mutation.AppendOptions(v)
return _u
}
// SetRequired sets the "required" field.
func (_u *UserAttributeDefinitionUpdateOne) SetRequired(v bool) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetRequired(v)
return _u
}
// SetNillableRequired sets the "required" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableRequired(v *bool) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetRequired(*v)
}
return _u
}
// SetValidation sets the "validation" field.
func (_u *UserAttributeDefinitionUpdateOne) SetValidation(v map[string]interface{}) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetValidation(v)
return _u
}
// SetPlaceholder sets the "placeholder" field.
func (_u *UserAttributeDefinitionUpdateOne) SetPlaceholder(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetPlaceholder(v)
return _u
}
// SetNillablePlaceholder sets the "placeholder" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillablePlaceholder(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetPlaceholder(*v)
}
return _u
}
// SetDisplayOrder sets the "display_order" field.
func (_u *UserAttributeDefinitionUpdateOne) SetDisplayOrder(v int) *UserAttributeDefinitionUpdateOne {
_u.mutation.ResetDisplayOrder()
_u.mutation.SetDisplayOrder(v)
return _u
}
// SetNillableDisplayOrder sets the "display_order" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableDisplayOrder(v *int) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetDisplayOrder(*v)
}
return _u
}
// AddDisplayOrder adds value to the "display_order" field.
func (_u *UserAttributeDefinitionUpdateOne) AddDisplayOrder(v int) *UserAttributeDefinitionUpdateOne {
_u.mutation.AddDisplayOrder(v)
return _u
}
// SetEnabled sets the "enabled" field.
func (_u *UserAttributeDefinitionUpdateOne) SetEnabled(v bool) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetEnabled(v)
return _u
}
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableEnabled(v *bool) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetEnabled(*v)
}
return _u
}
// AddValueIDs adds the "values" edge to the UserAttributeValue entity by IDs.
func (_u *UserAttributeDefinitionUpdateOne) AddValueIDs(ids ...int64) *UserAttributeDefinitionUpdateOne {
_u.mutation.AddValueIDs(ids...)
return _u
}
// AddValues adds the "values" edges to the UserAttributeValue entity.
func (_u *UserAttributeDefinitionUpdateOne) AddValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddValueIDs(ids...)
}
// Mutation returns the UserAttributeDefinitionMutation object of the builder.
func (_u *UserAttributeDefinitionUpdateOne) Mutation() *UserAttributeDefinitionMutation {
return _u.mutation
}
// ClearValues clears all "values" edges to the UserAttributeValue entity.
func (_u *UserAttributeDefinitionUpdateOne) ClearValues() *UserAttributeDefinitionUpdateOne {
_u.mutation.ClearValues()
return _u
}
// RemoveValueIDs removes the "values" edge to UserAttributeValue entities by IDs.
func (_u *UserAttributeDefinitionUpdateOne) RemoveValueIDs(ids ...int64) *UserAttributeDefinitionUpdateOne {
_u.mutation.RemoveValueIDs(ids...)
return _u
}
// RemoveValues removes "values" edges to UserAttributeValue entities.
func (_u *UserAttributeDefinitionUpdateOne) RemoveValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveValueIDs(ids...)
}
// Where appends a list predicates to the UserAttributeDefinitionUpdate builder.
func (_u *UserAttributeDefinitionUpdateOne) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UserAttributeDefinitionUpdateOne) Select(field string, fields ...string) *UserAttributeDefinitionUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UserAttributeDefinition entity.
func (_u *UserAttributeDefinitionUpdateOne) Save(ctx context.Context) (*UserAttributeDefinition, error) {
if err := _u.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserAttributeDefinitionUpdateOne) SaveX(ctx context.Context) *UserAttributeDefinition {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UserAttributeDefinitionUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserAttributeDefinitionUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserAttributeDefinitionUpdateOne) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if userattributedefinition.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized userattributedefinition.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := userattributedefinition.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserAttributeDefinitionUpdateOne) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := userattributedefinition.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.key": %w`, err)}
}
}
if v, ok := _u.mutation.Name(); ok {
if err := userattributedefinition.NameValidator(v); err != nil {
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.name": %w`, err)}
}
}
if v, ok := _u.mutation.GetType(); ok {
if err := userattributedefinition.TypeValidator(v); err != nil {
return &ValidationError{Name: "type", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.type": %w`, err)}
}
}
if v, ok := _u.mutation.Placeholder(); ok {
if err := userattributedefinition.PlaceholderValidator(v); err != nil {
return &ValidationError{Name: "placeholder", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.placeholder": %w`, err)}
}
}
return nil
}
func (_u *UserAttributeDefinitionUpdateOne) sqlSave(ctx context.Context) (_node *UserAttributeDefinition, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userattributedefinition.Table, userattributedefinition.Columns, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserAttributeDefinition.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userattributedefinition.FieldID)
for _, f := range fields {
if !userattributedefinition.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != userattributedefinition.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userattributedefinition.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(userattributedefinition.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(userattributedefinition.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(userattributedefinition.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Name(); ok {
_spec.SetField(userattributedefinition.FieldName, field.TypeString, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(userattributedefinition.FieldDescription, field.TypeString, value)
}
if value, ok := _u.mutation.GetType(); ok {
_spec.SetField(userattributedefinition.FieldType, field.TypeString, value)
}
if value, ok := _u.mutation.Options(); ok {
_spec.SetField(userattributedefinition.FieldOptions, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedOptions(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, userattributedefinition.FieldOptions, value)
})
}
if value, ok := _u.mutation.Required(); ok {
_spec.SetField(userattributedefinition.FieldRequired, field.TypeBool, value)
}
if value, ok := _u.mutation.Validation(); ok {
_spec.SetField(userattributedefinition.FieldValidation, field.TypeJSON, value)
}
if value, ok := _u.mutation.Placeholder(); ok {
_spec.SetField(userattributedefinition.FieldPlaceholder, field.TypeString, value)
}
if value, ok := _u.mutation.DisplayOrder(); ok {
_spec.SetField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedDisplayOrder(); ok {
_spec.AddField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.Enabled(); ok {
_spec.SetField(userattributedefinition.FieldEnabled, field.TypeBool, value)
}
if _u.mutation.ValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedValuesIDs(); len(nodes) > 0 && !_u.mutation.ValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserAttributeDefinition{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userattributedefinition.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -0,0 +1,198 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValue is the model entity for the UserAttributeValue schema.
type UserAttributeValue struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"`
// AttributeID holds the value of the "attribute_id" field.
AttributeID int64 `json:"attribute_id,omitempty"`
// Value holds the value of the "value" field.
Value string `json:"value,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserAttributeValueQuery when eager-loading is set.
Edges UserAttributeValueEdges `json:"edges"`
selectValues sql.SelectValues
}
// UserAttributeValueEdges holds the relations/edges for other nodes in the graph.
type UserAttributeValueEdges struct {
// User holds the value of the user edge.
User *User `json:"user,omitempty"`
// Definition holds the value of the definition edge.
Definition *UserAttributeDefinition `json:"definition,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [2]bool
}
// UserOrErr returns the User value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UserAttributeValueEdges) UserOrErr() (*User, error) {
if e.User != nil {
return e.User, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: user.Label}
}
return nil, &NotLoadedError{edge: "user"}
}
// DefinitionOrErr returns the Definition value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UserAttributeValueEdges) DefinitionOrErr() (*UserAttributeDefinition, error) {
if e.Definition != nil {
return e.Definition, nil
} else if e.loadedTypes[1] {
return nil, &NotFoundError{label: userattributedefinition.Label}
}
return nil, &NotLoadedError{edge: "definition"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UserAttributeValue) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case userattributevalue.FieldID, userattributevalue.FieldUserID, userattributevalue.FieldAttributeID:
values[i] = new(sql.NullInt64)
case userattributevalue.FieldValue:
values[i] = new(sql.NullString)
case userattributevalue.FieldCreatedAt, userattributevalue.FieldUpdatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the UserAttributeValue fields.
func (_m *UserAttributeValue) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case userattributevalue.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case userattributevalue.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case userattributevalue.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case userattributevalue.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i])
} else if value.Valid {
_m.UserID = value.Int64
}
case userattributevalue.FieldAttributeID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field attribute_id", values[i])
} else if value.Valid {
_m.AttributeID = value.Int64
}
case userattributevalue.FieldValue:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field value", values[i])
} else if value.Valid {
_m.Value = value.String
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// GetValue returns the ent.Value that was dynamically selected and assigned to the UserAttributeValue.
// This includes values selected through modifiers, order, etc.
func (_m *UserAttributeValue) GetValue(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryUser queries the "user" edge of the UserAttributeValue entity.
func (_m *UserAttributeValue) QueryUser() *UserQuery {
return NewUserAttributeValueClient(_m.config).QueryUser(_m)
}
// QueryDefinition queries the "definition" edge of the UserAttributeValue entity.
func (_m *UserAttributeValue) QueryDefinition() *UserAttributeDefinitionQuery {
return NewUserAttributeValueClient(_m.config).QueryDefinition(_m)
}
// Update returns a builder for updating this UserAttributeValue.
// Note that you need to call UserAttributeValue.Unwrap() before calling this method if this UserAttributeValue
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *UserAttributeValue) Update() *UserAttributeValueUpdateOne {
return NewUserAttributeValueClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the UserAttributeValue entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *UserAttributeValue) Unwrap() *UserAttributeValue {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: UserAttributeValue is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *UserAttributeValue) String() string {
var builder strings.Builder
builder.WriteString("UserAttributeValue(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ")
builder.WriteString("attribute_id=")
builder.WriteString(fmt.Sprintf("%v", _m.AttributeID))
builder.WriteString(", ")
builder.WriteString("value=")
builder.WriteString(_m.Value)
builder.WriteByte(')')
return builder.String()
}
// UserAttributeValues is a parsable slice of UserAttributeValue.
type UserAttributeValues []*UserAttributeValue

View File

@@ -0,0 +1,139 @@
// Code generated by ent, DO NOT EDIT.
package userattributevalue
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the userattributevalue type in the database.
Label = "user_attribute_value"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldAttributeID holds the string denoting the attribute_id field in the database.
FieldAttributeID = "attribute_id"
// FieldValue holds the string denoting the value field in the database.
FieldValue = "value"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// EdgeDefinition holds the string denoting the definition edge name in mutations.
EdgeDefinition = "definition"
// Table holds the table name of the userattributevalue in the database.
Table = "user_attribute_values"
// UserTable is the table that holds the user relation/edge.
UserTable = "user_attribute_values"
// UserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
UserInverseTable = "users"
// UserColumn is the table column denoting the user relation/edge.
UserColumn = "user_id"
// DefinitionTable is the table that holds the definition relation/edge.
DefinitionTable = "user_attribute_values"
// DefinitionInverseTable is the table name for the UserAttributeDefinition entity.
// It exists in this package in order to avoid circular dependency with the "userattributedefinition" package.
DefinitionInverseTable = "user_attribute_definitions"
// DefinitionColumn is the table column denoting the definition relation/edge.
DefinitionColumn = "attribute_id"
)
// Columns holds all SQL columns for userattributevalue fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldUserID,
FieldAttributeID,
FieldValue,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// DefaultValue holds the default value on creation for the "value" field.
DefaultValue string
)
// OrderOption defines the ordering options for the UserAttributeValue queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
}
// ByAttributeID orders the results by the attribute_id field.
func ByAttributeID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAttributeID, opts...).ToFunc()
}
// ByValue orders the results by the value field.
func ByValue(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldValue, opts...).ToFunc()
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
}
}
// ByDefinitionField orders the results by definition field.
func ByDefinitionField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newDefinitionStep(), sql.OrderByField(field, opts...))
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
}
func newDefinitionStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(DefinitionInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, DefinitionTable, DefinitionColumn),
)
}

View File

@@ -0,0 +1,327 @@
// Code generated by ent, DO NOT EDIT.
package userattributevalue
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldUpdatedAt, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldUserID, v))
}
// AttributeID applies equality check predicate on the "attribute_id" field. It's identical to AttributeIDEQ.
func AttributeID(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldAttributeID, v))
}
// Value applies equality check predicate on the "value" field. It's identical to ValueEQ.
func Value(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldValue, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLTE(FieldUpdatedAt, v))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldUserID, v))
}
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
func UserIDNEQ(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldUserID, v))
}
// UserIDIn applies the In predicate on the "user_id" field.
func UserIDIn(vs ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldUserID, vs...))
}
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
func UserIDNotIn(vs ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldUserID, vs...))
}
// AttributeIDEQ applies the EQ predicate on the "attribute_id" field.
func AttributeIDEQ(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldAttributeID, v))
}
// AttributeIDNEQ applies the NEQ predicate on the "attribute_id" field.
func AttributeIDNEQ(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldAttributeID, v))
}
// AttributeIDIn applies the In predicate on the "attribute_id" field.
func AttributeIDIn(vs ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldAttributeID, vs...))
}
// AttributeIDNotIn applies the NotIn predicate on the "attribute_id" field.
func AttributeIDNotIn(vs ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldAttributeID, vs...))
}
// ValueEQ applies the EQ predicate on the "value" field.
func ValueEQ(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldValue, v))
}
// ValueNEQ applies the NEQ predicate on the "value" field.
func ValueNEQ(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldValue, v))
}
// ValueIn applies the In predicate on the "value" field.
func ValueIn(vs ...string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldValue, vs...))
}
// ValueNotIn applies the NotIn predicate on the "value" field.
func ValueNotIn(vs ...string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldValue, vs...))
}
// ValueGT applies the GT predicate on the "value" field.
func ValueGT(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGT(FieldValue, v))
}
// ValueGTE applies the GTE predicate on the "value" field.
func ValueGTE(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGTE(FieldValue, v))
}
// ValueLT applies the LT predicate on the "value" field.
func ValueLT(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLT(FieldValue, v))
}
// ValueLTE applies the LTE predicate on the "value" field.
func ValueLTE(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLTE(FieldValue, v))
}
// ValueContains applies the Contains predicate on the "value" field.
func ValueContains(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldContains(FieldValue, v))
}
// ValueHasPrefix applies the HasPrefix predicate on the "value" field.
func ValueHasPrefix(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldHasPrefix(FieldValue, v))
}
// ValueHasSuffix applies the HasSuffix predicate on the "value" field.
func ValueHasSuffix(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldHasSuffix(FieldValue, v))
}
// ValueEqualFold applies the EqualFold predicate on the "value" field.
func ValueEqualFold(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEqualFold(FieldValue, v))
}
// ValueContainsFold applies the ContainsFold predicate on the "value" field.
func ValueContainsFold(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldContainsFold(FieldValue, v))
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.UserAttributeValue {
return predicate.UserAttributeValue(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
func HasUserWith(preds ...predicate.User) predicate.UserAttributeValue {
return predicate.UserAttributeValue(func(s *sql.Selector) {
step := newUserStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasDefinition applies the HasEdge predicate on the "definition" edge.
func HasDefinition() predicate.UserAttributeValue {
return predicate.UserAttributeValue(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, DefinitionTable, DefinitionColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasDefinitionWith applies the HasEdge predicate on the "definition" edge with a given conditions (other predicates).
func HasDefinitionWith(preds ...predicate.UserAttributeDefinition) predicate.UserAttributeValue {
return predicate.UserAttributeValue(func(s *sql.Selector) {
step := newDefinitionStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserAttributeValue) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.UserAttributeValue) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.UserAttributeValue) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.NotPredicates(p))
}

View File

@@ -0,0 +1,731 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValueCreate is the builder for creating a UserAttributeValue entity.
type UserAttributeValueCreate struct {
config
mutation *UserAttributeValueMutation
hooks []Hook
conflict []sql.ConflictOption
}
// SetCreatedAt sets the "created_at" field.
func (_c *UserAttributeValueCreate) SetCreatedAt(v time.Time) *UserAttributeValueCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_c *UserAttributeValueCreate) SetNillableCreatedAt(v *time.Time) *UserAttributeValueCreate {
if v != nil {
_c.SetCreatedAt(*v)
}
return _c
}
// SetUpdatedAt sets the "updated_at" field.
func (_c *UserAttributeValueCreate) SetUpdatedAt(v time.Time) *UserAttributeValueCreate {
_c.mutation.SetUpdatedAt(v)
return _c
}
// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
func (_c *UserAttributeValueCreate) SetNillableUpdatedAt(v *time.Time) *UserAttributeValueCreate {
if v != nil {
_c.SetUpdatedAt(*v)
}
return _c
}
// SetUserID sets the "user_id" field.
func (_c *UserAttributeValueCreate) SetUserID(v int64) *UserAttributeValueCreate {
_c.mutation.SetUserID(v)
return _c
}
// SetAttributeID sets the "attribute_id" field.
func (_c *UserAttributeValueCreate) SetAttributeID(v int64) *UserAttributeValueCreate {
_c.mutation.SetAttributeID(v)
return _c
}
// SetValue sets the "value" field.
func (_c *UserAttributeValueCreate) SetValue(v string) *UserAttributeValueCreate {
_c.mutation.SetValue(v)
return _c
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_c *UserAttributeValueCreate) SetNillableValue(v *string) *UserAttributeValueCreate {
if v != nil {
_c.SetValue(*v)
}
return _c
}
// SetUser sets the "user" edge to the User entity.
func (_c *UserAttributeValueCreate) SetUser(v *User) *UserAttributeValueCreate {
return _c.SetUserID(v.ID)
}
// SetDefinitionID sets the "definition" edge to the UserAttributeDefinition entity by ID.
func (_c *UserAttributeValueCreate) SetDefinitionID(id int64) *UserAttributeValueCreate {
_c.mutation.SetDefinitionID(id)
return _c
}
// SetDefinition sets the "definition" edge to the UserAttributeDefinition entity.
func (_c *UserAttributeValueCreate) SetDefinition(v *UserAttributeDefinition) *UserAttributeValueCreate {
return _c.SetDefinitionID(v.ID)
}
// Mutation returns the UserAttributeValueMutation object of the builder.
func (_c *UserAttributeValueCreate) Mutation() *UserAttributeValueMutation {
return _c.mutation
}
// Save creates the UserAttributeValue in the database.
func (_c *UserAttributeValueCreate) Save(ctx context.Context) (*UserAttributeValue, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *UserAttributeValueCreate) SaveX(ctx context.Context) *UserAttributeValue {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *UserAttributeValueCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *UserAttributeValueCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *UserAttributeValueCreate) defaults() {
if _, ok := _c.mutation.CreatedAt(); !ok {
v := userattributevalue.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
v := userattributevalue.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
if _, ok := _c.mutation.Value(); !ok {
v := userattributevalue.DefaultValue
_c.mutation.SetValue(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *UserAttributeValueCreate) check() error {
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UserAttributeValue.created_at"`)}
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "UserAttributeValue.updated_at"`)}
}
if _, ok := _c.mutation.UserID(); !ok {
return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "UserAttributeValue.user_id"`)}
}
if _, ok := _c.mutation.AttributeID(); !ok {
return &ValidationError{Name: "attribute_id", err: errors.New(`ent: missing required field "UserAttributeValue.attribute_id"`)}
}
if _, ok := _c.mutation.Value(); !ok {
return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "UserAttributeValue.value"`)}
}
if len(_c.mutation.UserIDs()) == 0 {
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "UserAttributeValue.user"`)}
}
if len(_c.mutation.DefinitionIDs()) == 0 {
return &ValidationError{Name: "definition", err: errors.New(`ent: missing required edge "UserAttributeValue.definition"`)}
}
return nil
}
func (_c *UserAttributeValueCreate) sqlSave(ctx context.Context) (*UserAttributeValue, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int64(id)
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *UserAttributeValueCreate) createSpec() (*UserAttributeValue, *sqlgraph.CreateSpec) {
var (
_node = &UserAttributeValue{config: _c.config}
_spec = sqlgraph.NewCreateSpec(userattributevalue.Table, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(userattributevalue.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.UpdatedAt(); ok {
_spec.SetField(userattributevalue.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value
}
if value, ok := _c.mutation.Value(); ok {
_spec.SetField(userattributevalue.FieldValue, field.TypeString, value)
_node.Value = value
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.UserID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.DefinitionIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.AttributeID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.UserAttributeValue.Create().
// SetCreatedAt(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.UserAttributeValueUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *UserAttributeValueCreate) OnConflict(opts ...sql.ConflictOption) *UserAttributeValueUpsertOne {
_c.conflict = opts
return &UserAttributeValueUpsertOne{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *UserAttributeValueCreate) OnConflictColumns(columns ...string) *UserAttributeValueUpsertOne {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &UserAttributeValueUpsertOne{
create: _c,
}
}
type (
// UserAttributeValueUpsertOne is the builder for "upsert"-ing
// one UserAttributeValue node.
UserAttributeValueUpsertOne struct {
create *UserAttributeValueCreate
}
// UserAttributeValueUpsert is the "OnConflict" setter.
UserAttributeValueUpsert struct {
*sql.UpdateSet
}
)
// SetUpdatedAt sets the "updated_at" field.
func (u *UserAttributeValueUpsert) SetUpdatedAt(v time.Time) *UserAttributeValueUpsert {
u.Set(userattributevalue.FieldUpdatedAt, v)
return u
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *UserAttributeValueUpsert) UpdateUpdatedAt() *UserAttributeValueUpsert {
u.SetExcluded(userattributevalue.FieldUpdatedAt)
return u
}
// SetUserID sets the "user_id" field.
func (u *UserAttributeValueUpsert) SetUserID(v int64) *UserAttributeValueUpsert {
u.Set(userattributevalue.FieldUserID, v)
return u
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsert) UpdateUserID() *UserAttributeValueUpsert {
u.SetExcluded(userattributevalue.FieldUserID)
return u
}
// SetAttributeID sets the "attribute_id" field.
func (u *UserAttributeValueUpsert) SetAttributeID(v int64) *UserAttributeValueUpsert {
u.Set(userattributevalue.FieldAttributeID, v)
return u
}
// UpdateAttributeID sets the "attribute_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsert) UpdateAttributeID() *UserAttributeValueUpsert {
u.SetExcluded(userattributevalue.FieldAttributeID)
return u
}
// SetValue sets the "value" field.
func (u *UserAttributeValueUpsert) SetValue(v string) *UserAttributeValueUpsert {
u.Set(userattributevalue.FieldValue, v)
return u
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *UserAttributeValueUpsert) UpdateValue() *UserAttributeValueUpsert {
u.SetExcluded(userattributevalue.FieldValue)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *UserAttributeValueUpsertOne) UpdateNewValues() *UserAttributeValueUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
if _, exists := u.create.mutation.CreatedAt(); exists {
s.SetIgnore(userattributevalue.FieldCreatedAt)
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *UserAttributeValueUpsertOne) Ignore() *UserAttributeValueUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *UserAttributeValueUpsertOne) DoNothing() *UserAttributeValueUpsertOne {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the UserAttributeValueCreate.OnConflict
// documentation for more info.
func (u *UserAttributeValueUpsertOne) Update(set func(*UserAttributeValueUpsert)) *UserAttributeValueUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&UserAttributeValueUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *UserAttributeValueUpsertOne) SetUpdatedAt(v time.Time) *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *UserAttributeValueUpsertOne) UpdateUpdatedAt() *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateUpdatedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserAttributeValueUpsertOne) SetUserID(v int64) *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsertOne) UpdateUserID() *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateUserID()
})
}
// SetAttributeID sets the "attribute_id" field.
func (u *UserAttributeValueUpsertOne) SetAttributeID(v int64) *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetAttributeID(v)
})
}
// UpdateAttributeID sets the "attribute_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsertOne) UpdateAttributeID() *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateAttributeID()
})
}
// SetValue sets the "value" field.
func (u *UserAttributeValueUpsertOne) SetValue(v string) *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetValue(v)
})
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *UserAttributeValueUpsertOne) UpdateValue() *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateValue()
})
}
// Exec executes the query.
func (u *UserAttributeValueUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for UserAttributeValueCreate.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *UserAttributeValueUpsertOne) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *UserAttributeValueUpsertOne) ID(ctx context.Context) (id int64, err error) {
node, err := u.create.Save(ctx)
if err != nil {
return id, err
}
return node.ID, nil
}
// IDX is like ID, but panics if an error occurs.
func (u *UserAttributeValueUpsertOne) IDX(ctx context.Context) int64 {
id, err := u.ID(ctx)
if err != nil {
panic(err)
}
return id
}
// UserAttributeValueCreateBulk is the builder for creating many UserAttributeValue entities in bulk.
type UserAttributeValueCreateBulk struct {
config
err error
builders []*UserAttributeValueCreate
conflict []sql.ConflictOption
}
// Save creates the UserAttributeValue entities in the database.
func (_c *UserAttributeValueCreateBulk) Save(ctx context.Context) ([]*UserAttributeValue, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*UserAttributeValue, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*UserAttributeValueMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
spec.OnConflict = _c.conflict
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int64(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *UserAttributeValueCreateBulk) SaveX(ctx context.Context) []*UserAttributeValue {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *UserAttributeValueCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *UserAttributeValueCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.UserAttributeValue.CreateBulk(builders...).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.UserAttributeValueUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *UserAttributeValueCreateBulk) OnConflict(opts ...sql.ConflictOption) *UserAttributeValueUpsertBulk {
_c.conflict = opts
return &UserAttributeValueUpsertBulk{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *UserAttributeValueCreateBulk) OnConflictColumns(columns ...string) *UserAttributeValueUpsertBulk {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &UserAttributeValueUpsertBulk{
create: _c,
}
}
// UserAttributeValueUpsertBulk is the builder for "upsert"-ing
// a bulk of UserAttributeValue nodes.
type UserAttributeValueUpsertBulk struct {
create *UserAttributeValueCreateBulk
}
// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *UserAttributeValueUpsertBulk) UpdateNewValues() *UserAttributeValueUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
for _, b := range u.create.builders {
if _, exists := b.mutation.CreatedAt(); exists {
s.SetIgnore(userattributevalue.FieldCreatedAt)
}
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *UserAttributeValueUpsertBulk) Ignore() *UserAttributeValueUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *UserAttributeValueUpsertBulk) DoNothing() *UserAttributeValueUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the UserAttributeValueCreateBulk.OnConflict
// documentation for more info.
func (u *UserAttributeValueUpsertBulk) Update(set func(*UserAttributeValueUpsert)) *UserAttributeValueUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&UserAttributeValueUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *UserAttributeValueUpsertBulk) SetUpdatedAt(v time.Time) *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *UserAttributeValueUpsertBulk) UpdateUpdatedAt() *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateUpdatedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserAttributeValueUpsertBulk) SetUserID(v int64) *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsertBulk) UpdateUserID() *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateUserID()
})
}
// SetAttributeID sets the "attribute_id" field.
func (u *UserAttributeValueUpsertBulk) SetAttributeID(v int64) *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetAttributeID(v)
})
}
// UpdateAttributeID sets the "attribute_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsertBulk) UpdateAttributeID() *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateAttributeID()
})
}
// SetValue sets the "value" field.
func (u *UserAttributeValueUpsertBulk) SetValue(v string) *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetValue(v)
})
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *UserAttributeValueUpsertBulk) UpdateValue() *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateValue()
})
}
// Exec executes the query.
func (u *UserAttributeValueUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
return u.create.err
}
for i, b := range u.create.builders {
if len(b.conflict) != 0 {
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UserAttributeValueCreateBulk instead", i)
}
}
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for UserAttributeValueCreateBulk.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *UserAttributeValueUpsertBulk) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValueDelete is the builder for deleting a UserAttributeValue entity.
type UserAttributeValueDelete struct {
config
hooks []Hook
mutation *UserAttributeValueMutation
}
// Where appends a list predicates to the UserAttributeValueDelete builder.
func (_d *UserAttributeValueDelete) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UserAttributeValueDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserAttributeValueDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UserAttributeValueDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(userattributevalue.Table, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UserAttributeValueDeleteOne is the builder for deleting a single UserAttributeValue entity.
type UserAttributeValueDeleteOne struct {
_d *UserAttributeValueDelete
}
// Where appends a list predicates to the UserAttributeValueDelete builder.
func (_d *UserAttributeValueDeleteOne) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UserAttributeValueDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{userattributevalue.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserAttributeValueDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,681 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValueQuery is the builder for querying UserAttributeValue entities.
type UserAttributeValueQuery struct {
config
ctx *QueryContext
order []userattributevalue.OrderOption
inters []Interceptor
predicates []predicate.UserAttributeValue
withUser *UserQuery
withDefinition *UserAttributeDefinitionQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UserAttributeValueQuery builder.
func (_q *UserAttributeValueQuery) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UserAttributeValueQuery) Limit(limit int) *UserAttributeValueQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UserAttributeValueQuery) Offset(offset int) *UserAttributeValueQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UserAttributeValueQuery) Unique(unique bool) *UserAttributeValueQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UserAttributeValueQuery) Order(o ...userattributevalue.OrderOption) *UserAttributeValueQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryUser chains the current query on the "user" edge.
func (_q *UserAttributeValueQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.UserTable, userattributevalue.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryDefinition chains the current query on the "definition" edge.
func (_q *UserAttributeValueQuery) QueryDefinition() *UserAttributeDefinitionQuery {
query := (&UserAttributeDefinitionClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, selector),
sqlgraph.To(userattributedefinition.Table, userattributedefinition.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.DefinitionTable, userattributevalue.DefinitionColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserAttributeValue entity from the query.
// Returns a *NotFoundError when no UserAttributeValue was found.
func (_q *UserAttributeValueQuery) First(ctx context.Context) (*UserAttributeValue, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{userattributevalue.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UserAttributeValueQuery) FirstX(ctx context.Context) *UserAttributeValue {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UserAttributeValue ID from the query.
// Returns a *NotFoundError when no UserAttributeValue ID was found.
func (_q *UserAttributeValueQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{userattributevalue.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UserAttributeValueQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UserAttributeValue entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UserAttributeValue entity is found.
// Returns a *NotFoundError when no UserAttributeValue entities are found.
func (_q *UserAttributeValueQuery) Only(ctx context.Context) (*UserAttributeValue, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{userattributevalue.Label}
default:
return nil, &NotSingularError{userattributevalue.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UserAttributeValueQuery) OnlyX(ctx context.Context) *UserAttributeValue {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UserAttributeValue ID in the query.
// Returns a *NotSingularError when more than one UserAttributeValue ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UserAttributeValueQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{userattributevalue.Label}
default:
err = &NotSingularError{userattributevalue.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UserAttributeValueQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UserAttributeValues.
func (_q *UserAttributeValueQuery) All(ctx context.Context) ([]*UserAttributeValue, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UserAttributeValue, *UserAttributeValueQuery]()
return withInterceptors[[]*UserAttributeValue](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UserAttributeValueQuery) AllX(ctx context.Context) []*UserAttributeValue {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UserAttributeValue IDs.
func (_q *UserAttributeValueQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(userattributevalue.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UserAttributeValueQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UserAttributeValueQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UserAttributeValueQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UserAttributeValueQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UserAttributeValueQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UserAttributeValueQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UserAttributeValueQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UserAttributeValueQuery) Clone() *UserAttributeValueQuery {
if _q == nil {
return nil
}
return &UserAttributeValueQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]userattributevalue.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UserAttributeValue{}, _q.predicates...),
withUser: _q.withUser.Clone(),
withDefinition: _q.withDefinition.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserAttributeValueQuery) WithUser(opts ...func(*UserQuery)) *UserAttributeValueQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// WithDefinition tells the query-builder to eager-load the nodes that are connected to
// the "definition" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserAttributeValueQuery) WithDefinition(opts ...func(*UserAttributeDefinitionQuery)) *UserAttributeValueQuery {
query := (&UserAttributeDefinitionClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withDefinition = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UserAttributeValue.Query().
// GroupBy(userattributevalue.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UserAttributeValueQuery) GroupBy(field string, fields ...string) *UserAttributeValueGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UserAttributeValueGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = userattributevalue.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UserAttributeValue.Query().
// Select(userattributevalue.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UserAttributeValueQuery) Select(fields ...string) *UserAttributeValueSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UserAttributeValueSelect{UserAttributeValueQuery: _q}
sbuild.label = userattributevalue.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UserAttributeValueSelect configured with the given aggregations.
func (_q *UserAttributeValueQuery) Aggregate(fns ...AggregateFunc) *UserAttributeValueSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UserAttributeValueQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !userattributevalue.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserAttributeValue, error) {
var (
nodes = []*UserAttributeValue{}
_spec = _q.querySpec()
loadedTypes = [2]bool{
_q.withUser != nil,
_q.withDefinition != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UserAttributeValue).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UserAttributeValue{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *UserAttributeValue, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
if query := _q.withDefinition; query != nil {
if err := _q.loadDefinition(ctx, query, nodes, nil,
func(n *UserAttributeValue, e *UserAttributeDefinition) { n.Edges.Definition = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *UserAttributeValueQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserAttributeValue, init func(*UserAttributeValue), assign func(*UserAttributeValue, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserAttributeValue)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *UserAttributeDefinitionQuery, nodes []*UserAttributeValue, init func(*UserAttributeValue), assign func(*UserAttributeValue, *UserAttributeDefinition)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserAttributeValue)
for i := range nodes {
fk := nodes[i].AttributeID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(userattributedefinition.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "attribute_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UserAttributeValueQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(userattributevalue.Table, userattributevalue.Columns, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userattributevalue.FieldID)
for i := range fields {
if fields[i] != userattributevalue.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(userattributevalue.FieldUserID)
}
if _q.withDefinition != nil {
_spec.Node.AddColumnOnce(userattributevalue.FieldAttributeID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(userattributevalue.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = userattributevalue.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
type UserAttributeValueGroupBy struct {
selector
build *UserAttributeValueQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UserAttributeValueGroupBy) Aggregate(fns ...AggregateFunc) *UserAttributeValueGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UserAttributeValueGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserAttributeValueQuery, *UserAttributeValueGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UserAttributeValueGroupBy) sqlScan(ctx context.Context, root *UserAttributeValueQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UserAttributeValueSelect is the builder for selecting fields of UserAttributeValue entities.
type UserAttributeValueSelect struct {
*UserAttributeValueQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UserAttributeValueSelect) Aggregate(fns ...AggregateFunc) *UserAttributeValueSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UserAttributeValueSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserAttributeValueQuery, *UserAttributeValueSelect](ctx, _s.UserAttributeValueQuery, _s, _s.inters, v)
}
func (_s *UserAttributeValueSelect) sqlScan(ctx context.Context, root *UserAttributeValueQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,504 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValueUpdate is the builder for updating UserAttributeValue entities.
type UserAttributeValueUpdate struct {
config
hooks []Hook
mutation *UserAttributeValueMutation
}
// Where appends a list predicates to the UserAttributeValueUpdate builder.
func (_u *UserAttributeValueUpdate) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserAttributeValueUpdate) SetUpdatedAt(v time.Time) *UserAttributeValueUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserAttributeValueUpdate) SetUserID(v int64) *UserAttributeValueUpdate {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserAttributeValueUpdate) SetNillableUserID(v *int64) *UserAttributeValueUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetAttributeID sets the "attribute_id" field.
func (_u *UserAttributeValueUpdate) SetAttributeID(v int64) *UserAttributeValueUpdate {
_u.mutation.SetAttributeID(v)
return _u
}
// SetNillableAttributeID sets the "attribute_id" field if the given value is not nil.
func (_u *UserAttributeValueUpdate) SetNillableAttributeID(v *int64) *UserAttributeValueUpdate {
if v != nil {
_u.SetAttributeID(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *UserAttributeValueUpdate) SetValue(v string) *UserAttributeValueUpdate {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *UserAttributeValueUpdate) SetNillableValue(v *string) *UserAttributeValueUpdate {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserAttributeValueUpdate) SetUser(v *User) *UserAttributeValueUpdate {
return _u.SetUserID(v.ID)
}
// SetDefinitionID sets the "definition" edge to the UserAttributeDefinition entity by ID.
func (_u *UserAttributeValueUpdate) SetDefinitionID(id int64) *UserAttributeValueUpdate {
_u.mutation.SetDefinitionID(id)
return _u
}
// SetDefinition sets the "definition" edge to the UserAttributeDefinition entity.
func (_u *UserAttributeValueUpdate) SetDefinition(v *UserAttributeDefinition) *UserAttributeValueUpdate {
return _u.SetDefinitionID(v.ID)
}
// Mutation returns the UserAttributeValueMutation object of the builder.
func (_u *UserAttributeValueUpdate) Mutation() *UserAttributeValueMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserAttributeValueUpdate) ClearUser() *UserAttributeValueUpdate {
_u.mutation.ClearUser()
return _u
}
// ClearDefinition clears the "definition" edge to the UserAttributeDefinition entity.
func (_u *UserAttributeValueUpdate) ClearDefinition() *UserAttributeValueUpdate {
_u.mutation.ClearDefinition()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserAttributeValueUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserAttributeValueUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UserAttributeValueUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserAttributeValueUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserAttributeValueUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := userattributevalue.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserAttributeValueUpdate) check() error {
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.user"`)
}
if _u.mutation.DefinitionCleared() && len(_u.mutation.DefinitionIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.definition"`)
}
return nil
}
func (_u *UserAttributeValueUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userattributevalue.Table, userattributevalue.Columns, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userattributevalue.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(userattributevalue.FieldValue, field.TypeString, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.DefinitionCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.DefinitionIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userattributevalue.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UserAttributeValueUpdateOne is the builder for updating a single UserAttributeValue entity.
type UserAttributeValueUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserAttributeValueMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserAttributeValueUpdateOne) SetUpdatedAt(v time.Time) *UserAttributeValueUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserAttributeValueUpdateOne) SetUserID(v int64) *UserAttributeValueUpdateOne {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserAttributeValueUpdateOne) SetNillableUserID(v *int64) *UserAttributeValueUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetAttributeID sets the "attribute_id" field.
func (_u *UserAttributeValueUpdateOne) SetAttributeID(v int64) *UserAttributeValueUpdateOne {
_u.mutation.SetAttributeID(v)
return _u
}
// SetNillableAttributeID sets the "attribute_id" field if the given value is not nil.
func (_u *UserAttributeValueUpdateOne) SetNillableAttributeID(v *int64) *UserAttributeValueUpdateOne {
if v != nil {
_u.SetAttributeID(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *UserAttributeValueUpdateOne) SetValue(v string) *UserAttributeValueUpdateOne {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *UserAttributeValueUpdateOne) SetNillableValue(v *string) *UserAttributeValueUpdateOne {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserAttributeValueUpdateOne) SetUser(v *User) *UserAttributeValueUpdateOne {
return _u.SetUserID(v.ID)
}
// SetDefinitionID sets the "definition" edge to the UserAttributeDefinition entity by ID.
func (_u *UserAttributeValueUpdateOne) SetDefinitionID(id int64) *UserAttributeValueUpdateOne {
_u.mutation.SetDefinitionID(id)
return _u
}
// SetDefinition sets the "definition" edge to the UserAttributeDefinition entity.
func (_u *UserAttributeValueUpdateOne) SetDefinition(v *UserAttributeDefinition) *UserAttributeValueUpdateOne {
return _u.SetDefinitionID(v.ID)
}
// Mutation returns the UserAttributeValueMutation object of the builder.
func (_u *UserAttributeValueUpdateOne) Mutation() *UserAttributeValueMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserAttributeValueUpdateOne) ClearUser() *UserAttributeValueUpdateOne {
_u.mutation.ClearUser()
return _u
}
// ClearDefinition clears the "definition" edge to the UserAttributeDefinition entity.
func (_u *UserAttributeValueUpdateOne) ClearDefinition() *UserAttributeValueUpdateOne {
_u.mutation.ClearDefinition()
return _u
}
// Where appends a list predicates to the UserAttributeValueUpdate builder.
func (_u *UserAttributeValueUpdateOne) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UserAttributeValueUpdateOne) Select(field string, fields ...string) *UserAttributeValueUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UserAttributeValue entity.
func (_u *UserAttributeValueUpdateOne) Save(ctx context.Context) (*UserAttributeValue, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserAttributeValueUpdateOne) SaveX(ctx context.Context) *UserAttributeValue {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UserAttributeValueUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserAttributeValueUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserAttributeValueUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := userattributevalue.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserAttributeValueUpdateOne) check() error {
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.user"`)
}
if _u.mutation.DefinitionCleared() && len(_u.mutation.DefinitionIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.definition"`)
}
return nil
}
func (_u *UserAttributeValueUpdateOne) sqlSave(ctx context.Context) (_node *UserAttributeValue, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userattributevalue.Table, userattributevalue.Columns, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserAttributeValue.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userattributevalue.FieldID)
for _, f := range fields {
if !userattributevalue.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != userattributevalue.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userattributevalue.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(userattributevalue.FieldValue, field.TypeString, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.DefinitionCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.DefinitionIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserAttributeValue{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userattributevalue.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -3,6 +3,7 @@ package config
import (
"fmt"
"strings"
"time"
"github.com/spf13/viper"
)
@@ -43,6 +44,7 @@ type Config struct {
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
Quota GeminiQuotaConfig `mapstructure:"quota"`
}
type GeminiOAuthConfig struct {
@@ -51,6 +53,17 @@ type GeminiOAuthConfig struct {
Scopes string `mapstructure:"scopes"`
}
type GeminiQuotaConfig struct {
Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"`
Policy string `mapstructure:"policy"`
}
type GeminiTierQuotaConfig struct {
ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"`
FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"`
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@@ -119,6 +132,37 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
type GatewaySchedulingConfig struct {
// 粘性会话排队配置
StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"`
StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"`
// 兜底排队配置
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
// 过期槽位清理周期0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
}
func (s *ServerConfig) Address() string {
@@ -313,6 +357,10 @@ func setDefaults() {
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
@@ -323,6 +371,12 @@ func setDefaults() {
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
@@ -337,6 +391,7 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
}
func (c *Config) Validate() error {
@@ -411,6 +466,21 @@ func (c *Config) Validate() error {
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
}
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
return nil
}

View File

@@ -1,6 +1,11 @@
package config
import "testing"
import (
"testing"
"time"
"github.com/spf13/viper"
)
func TestNormalizeRunMode(t *testing.T) {
tests := []struct {
@@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) {
}
}
}
func TestLoadDefaultSchedulingConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 {
t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting)
}
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
t.Fatalf("LoadBatchEnabled = false, want true")
}
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
viper.Reset()
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 {
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
}

View File

@@ -3,6 +3,7 @@ package admin
import (
"strconv"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
@@ -13,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
// OAuthHandler handles OAuth-related operations for accounts
@@ -989,3 +991,164 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
response.Success(c, models)
}
// RefreshTier handles refreshing Google One tier for a single account
// POST /api/v1/admin/accounts/:id/refresh-tier
func (h *AccountHandler) RefreshTier(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
ctx := c.Request.Context()
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
if account.Platform != service.PlatformGemini || account.Type != service.AccountTypeOAuth {
response.BadRequest(c, "Only Gemini OAuth accounts support tier refresh")
return
}
oauthType, _ := account.Credentials["oauth_type"].(string)
if oauthType != "google_one" {
response.BadRequest(c, "Only google_one OAuth accounts support tier refresh")
return
}
tierID, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(ctx, account)
if err != nil {
response.ErrorFrom(c, err)
return
}
_, updateErr := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
Credentials: creds,
Extra: extra,
})
if updateErr != nil {
response.ErrorFrom(c, updateErr)
return
}
response.Success(c, gin.H{
"tier_id": tierID,
"storage_info": extra,
"drive_storage_limit": extra["drive_storage_limit"],
"drive_storage_usage": extra["drive_storage_usage"],
"updated_at": extra["drive_tier_updated_at"],
})
}
// BatchRefreshTierRequest represents batch tier refresh request
type BatchRefreshTierRequest struct {
AccountIDs []int64 `json:"account_ids"`
}
// BatchRefreshTier handles batch refreshing Google One tier
// POST /api/v1/admin/accounts/batch-refresh-tier
func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
var req BatchRefreshTierRequest
if err := c.ShouldBindJSON(&req); err != nil {
req = BatchRefreshTierRequest{}
}
ctx := c.Request.Context()
accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
if err != nil {
response.ErrorFrom(c, err)
return
}
for i := range allAccounts {
acc := &allAccounts[i]
oauthType, _ := acc.Credentials["oauth_type"].(string)
if oauthType == "google_one" {
accounts = append(accounts, acc)
}
}
} else {
fetched, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
for _, acc := range fetched {
if acc == nil {
continue
}
if acc.Platform != service.PlatformGemini || acc.Type != service.AccountTypeOAuth {
continue
}
oauthType, _ := acc.Credentials["oauth_type"].(string)
if oauthType != "google_one" {
continue
}
accounts = append(accounts, acc)
}
}
const maxConcurrency = 10
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(maxConcurrency)
var mu sync.Mutex
var successCount, failedCount int
var errors []gin.H
for _, account := range accounts {
acc := account // 闭包捕获
g.Go(func() error {
_, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(gctx, acc)
if err != nil {
mu.Lock()
failedCount++
errors = append(errors, gin.H{
"account_id": acc.ID,
"error": err.Error(),
})
mu.Unlock()
return nil
}
_, updateErr := h.adminService.UpdateAccount(gctx, acc.ID, &service.UpdateAccountInput{
Credentials: creds,
Extra: extra,
})
mu.Lock()
if updateErr != nil {
failedCount++
errors = append(errors, gin.H{
"account_id": acc.ID,
"error": updateErr.Error(),
})
} else {
successCount++
}
mu.Unlock()
return nil
})
}
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
results := gin.H{
"total": len(accounts),
"success": successCount,
"failed": failedCount,
"errors": errors,
}
response.Success(c, results)
}

View File

@@ -46,8 +46,8 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
return
}
@@ -92,8 +92,8 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
return
}

View File

@@ -246,7 +246,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
}
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -0,0 +1,342 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserAttributeHandler handles user attribute management
type UserAttributeHandler struct {
attrService *service.UserAttributeService
}
// NewUserAttributeHandler creates a new handler
func NewUserAttributeHandler(attrService *service.UserAttributeService) *UserAttributeHandler {
return &UserAttributeHandler{attrService: attrService}
}
// --- Request/Response DTOs ---
// CreateAttributeDefinitionRequest represents create attribute definition request
type CreateAttributeDefinitionRequest struct {
Key string `json:"key" binding:"required,min=1,max=100"`
Name string `json:"name" binding:"required,min=1,max=255"`
Description string `json:"description"`
Type string `json:"type" binding:"required"`
Options []service.UserAttributeOption `json:"options"`
Required bool `json:"required"`
Validation service.UserAttributeValidation `json:"validation"`
Placeholder string `json:"placeholder"`
Enabled bool `json:"enabled"`
}
// UpdateAttributeDefinitionRequest represents update attribute definition request
type UpdateAttributeDefinitionRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
Type *string `json:"type"`
Options *[]service.UserAttributeOption `json:"options"`
Required *bool `json:"required"`
Validation *service.UserAttributeValidation `json:"validation"`
Placeholder *string `json:"placeholder"`
Enabled *bool `json:"enabled"`
}
// ReorderRequest represents reorder attribute definitions request
type ReorderRequest struct {
IDs []int64 `json:"ids" binding:"required"`
}
// UpdateUserAttributesRequest represents update user attributes request
type UpdateUserAttributesRequest struct {
Values map[int64]string `json:"values" binding:"required"`
}
// BatchGetUserAttributesRequest represents batch get user attributes request
type BatchGetUserAttributesRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
// BatchUserAttributesResponse represents batch user attributes response
type BatchUserAttributesResponse struct {
// Map of userID -> map of attributeID -> value
Attributes map[int64]map[int64]string `json:"attributes"`
}
// AttributeDefinitionResponse represents attribute definition response
type AttributeDefinitionResponse struct {
ID int64 `json:"id"`
Key string `json:"key"`
Name string `json:"name"`
Description string `json:"description"`
Type string `json:"type"`
Options []service.UserAttributeOption `json:"options"`
Required bool `json:"required"`
Validation service.UserAttributeValidation `json:"validation"`
Placeholder string `json:"placeholder"`
DisplayOrder int `json:"display_order"`
Enabled bool `json:"enabled"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// AttributeValueResponse represents attribute value response
type AttributeValueResponse struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
AttributeID int64 `json:"attribute_id"`
Value string `json:"value"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// --- Helpers ---
func defToResponse(def *service.UserAttributeDefinition) *AttributeDefinitionResponse {
return &AttributeDefinitionResponse{
ID: def.ID,
Key: def.Key,
Name: def.Name,
Description: def.Description,
Type: string(def.Type),
Options: def.Options,
Required: def.Required,
Validation: def.Validation,
Placeholder: def.Placeholder,
DisplayOrder: def.DisplayOrder,
Enabled: def.Enabled,
CreatedAt: def.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
UpdatedAt: def.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
}
}
func valueToResponse(val *service.UserAttributeValue) *AttributeValueResponse {
return &AttributeValueResponse{
ID: val.ID,
UserID: val.UserID,
AttributeID: val.AttributeID,
Value: val.Value,
CreatedAt: val.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
UpdatedAt: val.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
}
}
// --- Handlers ---
// ListDefinitions lists all attribute definitions
// GET /admin/user-attributes
func (h *UserAttributeHandler) ListDefinitions(c *gin.Context) {
enabledOnly := c.Query("enabled") == "true"
defs, err := h.attrService.ListDefinitions(c.Request.Context(), enabledOnly)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeDefinitionResponse, 0, len(defs))
for i := range defs {
out = append(out, defToResponse(&defs[i]))
}
response.Success(c, out)
}
// CreateDefinition creates a new attribute definition
// POST /admin/user-attributes
func (h *UserAttributeHandler) CreateDefinition(c *gin.Context) {
var req CreateAttributeDefinitionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
def, err := h.attrService.CreateDefinition(c.Request.Context(), service.CreateAttributeDefinitionInput{
Key: req.Key,
Name: req.Name,
Description: req.Description,
Type: service.UserAttributeType(req.Type),
Options: req.Options,
Required: req.Required,
Validation: req.Validation,
Placeholder: req.Placeholder,
Enabled: req.Enabled,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, defToResponse(def))
}
// UpdateDefinition updates an attribute definition
// PUT /admin/user-attributes/:id
func (h *UserAttributeHandler) UpdateDefinition(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid attribute ID")
return
}
var req UpdateAttributeDefinitionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := service.UpdateAttributeDefinitionInput{
Name: req.Name,
Description: req.Description,
Options: req.Options,
Required: req.Required,
Validation: req.Validation,
Placeholder: req.Placeholder,
Enabled: req.Enabled,
}
if req.Type != nil {
t := service.UserAttributeType(*req.Type)
input.Type = &t
}
def, err := h.attrService.UpdateDefinition(c.Request.Context(), id, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, defToResponse(def))
}
// DeleteDefinition deletes an attribute definition
// DELETE /admin/user-attributes/:id
func (h *UserAttributeHandler) DeleteDefinition(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid attribute ID")
return
}
if err := h.attrService.DeleteDefinition(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Attribute definition deleted successfully"})
}
// ReorderDefinitions reorders attribute definitions
// PUT /admin/user-attributes/reorder
func (h *UserAttributeHandler) ReorderDefinitions(c *gin.Context) {
var req ReorderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Convert IDs array to orders map (position in array = display_order)
orders := make(map[int64]int, len(req.IDs))
for i, id := range req.IDs {
orders[id] = i
}
if err := h.attrService.ReorderDefinitions(c.Request.Context(), orders); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Reorder successful"})
}
// GetUserAttributes gets a user's attribute values
// GET /admin/users/:id/attributes
func (h *UserAttributeHandler) GetUserAttributes(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeValueResponse, 0, len(values))
for i := range values {
out = append(out, valueToResponse(&values[i]))
}
response.Success(c, out)
}
// UpdateUserAttributes updates a user's attribute values
// PUT /admin/users/:id/attributes
func (h *UserAttributeHandler) UpdateUserAttributes(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserAttributesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
inputs := make([]service.UpdateUserAttributeInput, 0, len(req.Values))
for attrID, value := range req.Values {
inputs = append(inputs, service.UpdateUserAttributeInput{
AttributeID: attrID,
Value: value,
})
}
if err := h.attrService.UpdateUserAttributes(c.Request.Context(), userID, inputs); err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated values
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeValueResponse, 0, len(values))
for i := range values {
out = append(out, valueToResponse(&values[i]))
}
response.Success(c, out)
}
// GetBatchUserAttributes gets attribute values for multiple users
// POST /admin/user-attributes/batch
func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
var req BatchGetUserAttributesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.UserIDs) == 0 {
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
return
}
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
}

View File

@@ -27,7 +27,6 @@ type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
@@ -40,7 +39,6 @@ type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"`
Username *string `json:"username"`
Wechat *string `json:"wechat"`
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
@@ -57,13 +55,22 @@ type UpdateBalanceRequest struct {
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
// - status: filter by user status
// - role: filter by user role
// - search: search in email, username
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
role := c.Query("role")
search := c.Query("search")
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
filters := service.UserListFilters{
Status: c.Query("status"),
Role: c.Query("role"),
Search: c.Query("search"),
Attributes: parseAttributeFilters(c),
}
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -76,6 +83,29 @@ func (h *UserHandler) List(c *gin.Context) {
response.Paginated(c, out, total, page, pageSize)
}
// parseAttributeFilters extracts attribute filters from query params
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
func parseAttributeFilters(c *gin.Context) map[int64]string {
result := make(map[int64]string)
// Get all query params and look for attr[*] pattern
for key, values := range c.Request.URL.Query() {
if len(values) == 0 || values[0] == "" {
continue
}
// Check if key matches pattern attr[{id}]
if len(key) > 5 && key[:5] == "attr[" && key[len(key)-1] == ']' {
idStr := key[5 : len(key)-1]
id, err := strconv.ParseInt(idStr, 10, 64)
if err == nil && id > 0 {
result[id] = values[0]
}
}
}
return result
}
// GetByID handles getting a user by ID
// GET /api/v1/admin/users/:id
func (h *UserHandler) GetByID(c *gin.Context) {
@@ -107,7 +137,6 @@ func (h *UserHandler) Create(c *gin.Context) {
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
@@ -141,7 +170,6 @@ func (h *UserHandler) Update(c *gin.Context) {
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,

View File

@@ -10,7 +10,6 @@ func UserFromServiceShallow(u *service.User) *User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,

View File

@@ -6,7 +6,6 @@ type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`

View File

@@ -142,6 +142,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
sessionKey := sessionHash
if platform == service.PlatformGemini && sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
if platform == service.PlatformGemini {
const maxAccountSwitches = 3
@@ -150,7 +154,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -159,9 +163,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
@@ -171,11 +179,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
@@ -188,6 +231,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -232,7 +278,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -241,9 +287,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
@@ -253,11 +303,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
@@ -270,6 +355,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -309,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Models handles listing available models
// GET /v1/models
// Returns different model lists based on the API key's group platform or forced platform
// Returns models based on account configurations (model_mapping whitelist)
// Falls back to default models if no whitelist is configured
func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware2.GetApiKeyFromContext(c)
@@ -324,8 +413,37 @@ func (h *GatewayHandler) Models(c *gin.Context) {
}
}
// Return OpenAI models for OpenAI platform groups
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
var groupID *int64
var platform string
if apiKey != nil && apiKey.Group != nil {
groupID = &apiKey.Group.ID
platform = apiKey.Group.Platform
}
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
if len(availableModels) > 0 {
// Build model list from whitelist
models := make([]claude.Model, 0, len(availableModels))
for _, modelID := range availableModels {
models = append(models, claude.Model{
ID: modelID,
Type: "model",
DisplayName: modelID,
CreatedAt: "2024-01-01T00:00:00Z",
})
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
})
return
}
// Fallback to default models
if platform == "openai" {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": openai.DefaultModels,
@@ -333,7 +451,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
return
}
// Default: Claude models
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": claude.DefaultModels,

View File

@@ -83,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h.concurrencyService.DecrementWaitCount(ctx, userID)
}
// IncrementAccountWaitCount increments the wait count for an account
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
@@ -126,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// Determine if ping is needed (streaming + ping format defined)
@@ -200,6 +215,11 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
}
}
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间

View File

@@ -198,13 +198,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@@ -213,12 +217,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
account := selection.Account
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
stream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 5) forward (根据平台分流)
@@ -231,6 +271,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {

View File

@@ -20,6 +20,7 @@ type AdminHandlers struct {
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
}
// Handlers contains all HTTP handlers

View File

@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
@@ -156,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// Forward request
@@ -171,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {

View File

@@ -30,7 +30,6 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
Wechat *string `json:"wechat"`
}
// GetProfile handles getting user profile
@@ -99,7 +98,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{
Username: req.Username,
Wechat: req.Wechat,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {

View File

@@ -23,6 +23,7 @@ func ProvideAdminHandlers(
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
System: systemHandler,
Subscription: subscriptionHandler,
Usage: usageHandler,
UserAttribute: userAttributeHandler,
}
}
@@ -107,6 +109,7 @@ var ProviderSet = wire.NewSet(
ProvideSystemHandler,
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -57,6 +57,7 @@ var geminiModels = []string{
"gemini-2.5-flash-lite",
"gemini-3-flash",
"gemini-3-pro-low",
"gemini-3-pro-high",
}
func TestMain(m *testing.M) {
@@ -641,6 +642,37 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"])
}
// TestClaudeMessagesWithGeminiModel 测试在 Claude 端点使用 Gemini 模型
// 验证:通过 /v1/messages 端点传入 gemini 模型名的场景(含前缀映射)
// 仅在 Antigravity 模式下运行ENDPOINT_PREFIX="/antigravity"
func TestClaudeMessagesWithGeminiModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Claude 端点调用 Gemini 模型
geminiViaClaude := []string{
"gemini-3-flash", // 直接支持
"gemini-3-pro-low", // 直接支持
"gemini-3-pro-high", // 直接支持
"gemini-3-pro", // 前缀映射 -> gemini-3-pro-high
"gemini-3-pro-preview", // 前缀映射 -> gemini-3-pro-high
}
for i, model := range geminiViaClaude {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Claude端点", func(t *testing.T) {
testClaudeMessage(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
testClaudeMessage(t, model, true)
})
}
}
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证Gemini 模型接受没有 signature 的 thinking block
func TestClaudeMessagesWithNoSignature(t *testing.T) {
@@ -738,3 +770,30 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
}
t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"])
}
// TestGeminiEndpointWithClaudeModel 测试通过 Gemini 端点调用 Claude 模型
// 仅在 Antigravity 模式下运行ENDPOINT_PREFIX="/antigravity"
func TestGeminiEndpointWithClaudeModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Gemini 端点调用 Claude 模型
claudeViaGemini := []string{
"claude-sonnet-4-5",
"claude-opus-4-5-thinking",
}
for i, model := range claudeViaGemini {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
testGeminiGenerate(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true)
})
}
}

View File

@@ -54,6 +54,9 @@ type CustomToolSpec struct {
InputSchema map[string]any `json:"input_schema"`
}
// ClaudeCustomToolSpec 兼容旧命名MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素
type SystemBlock struct {
Type string `json:"type"`

View File

@@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking除非未来支持完整签名链路
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
@@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq)
reqForGen := claudeReq
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
// 4. 构建 tools
tools := buildTools(claudeReq.Tools)
@@ -148,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
if !hasThoughtPart && len(parts) > 0 {
// 在开头添加 dummy thinking block
parts = append([]GeminiPart{{
Text: "Thinking...",
Thought: true,
Text: "Thinking...",
Thought: true,
ThoughtSignature: dummyThoughtSignature,
}}, parts...)
}
}
@@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
@@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
case "thinking":
part := GeminiPart{
Text: block.Thinking,
Thought: true,
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature跳过无 signature 的 thinking block
log.Printf("Warning: skipping thinking block without signature for Claude model")
if allowDummyThought {
// Gemini 模型可以使用 dummy signature
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: dummyThoughtSignature,
})
continue
} else {
// Gemini 模型使用 dummy signature
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
// Claude 模型:仅在提供有效 signature 时保留 thinking block否则跳过以避免上游校验失败。
signature := strings.TrimSpace(block.Signature)
if signature == "" || signature == dummyThoughtSignature {
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
continue
}
if !isValidThoughtSignature(signature) {
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
case "image":
if block.Source != nil && block.Source.Type == "base64" {
@@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID,
},
}
// 保留原有 signature或对 Gemini 模型使用 dummy signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
// 只有 Gemini 模型使用 dummy signature
// Claude 模型不设置 signature避免验证问题
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
@@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具
var funcDecls []GeminiFunctionDecl
for _, tool := range tools {
for i, tool := range tools {
// 跳过无效工具名称
if tool.Name == "" {
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
continue
}
@@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
var inputSchema map[string]any
// 检查是否为 custom 类型工具 (MCP)
if tool.Type == "custom" && tool.Custom != nil {
// Custom 格式: 从 custom 字段获取 description 和 input_schema
if tool.Type == "custom" {
if tool.Custom == nil || tool.Custom.InputSchema == nil {
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
continue
}
description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else {
// 标准格式: 从顶层字段获取
description = tool.Description
@@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 清理 JSON Schema
params := cleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
@@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
}
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name,
Description: description,
@@ -479,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
"$schema": true,
"$id": true,
"$ref": true,
"additionalProperties": true,
"minLength": true,
"maxLength": true,
"minItems": true,
"maxItems": true,
"uniqueItems": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"pattern": true,
"format": true,
"default": true,
"strict": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
// 字符串验证Gemini 不支持)
"minLength": true,
"maxLength": true,
"pattern": true,
// 数字验证Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
// 数组验证Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schemaGemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
@@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any {
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalPropertiesClaude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val)
}

View File

@@ -0,0 +1,179 @@
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &CustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &CustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &CustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}

View File

@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",

View File

@@ -0,0 +1,157 @@
package geminicli
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
// DriveStorageInfo represents Google Drive storage quota information
type DriveStorageInfo struct {
Limit int64 `json:"limit"` // Storage limit in bytes
Usage int64 `json:"usage"` // Current usage in bytes
}
// DriveClient interface for Google Drive API operations
type DriveClient interface {
GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error)
}
type driveClient struct{}
// NewDriveClient creates a new Drive API client
func NewDriveClient() DriveClient {
return &driveClient{}
}
// GetStorageQuota fetches storage quota from Google Drive API
func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error) {
const driveAPIURL = "https://www.googleapis.com/drive/v3/about?fields=storageQuota"
req, err := http.NewRequestWithContext(ctx, "GET", driveAPIURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
// Get HTTP client with proxy support
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 10 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
}
sleepWithContext := func(d time.Duration) error {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
var resp *http.Response
maxRetries := 3
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for attempt := 0; attempt < maxRetries; attempt++ {
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
}
resp, err = client.Do(req)
if err != nil {
// Network error retry
if attempt < maxRetries-1 {
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
if err := sleepWithContext(backoff + jitter); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
}
// Success
if resp.StatusCode == http.StatusOK {
break
}
// Retry 429, 500, 502, 503 with exponential backoff + jitter
if (resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusInternalServerError ||
resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
if err := func() error {
defer func() { _ = resp.Body.Close() }()
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
return sleepWithContext(backoff + jitter)
}(); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
break
}
if resp == nil {
return nil, fmt.Errorf("request failed: no response received")
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
statusText := http.StatusText(resp.StatusCode)
if statusText == "" {
statusText = resp.Status
}
fmt.Printf("[DriveClient] Drive API error: status=%d, msg=%s\n", resp.StatusCode, statusText)
// 只返回通用错误
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
}
defer func() { _ = resp.Body.Close() }()
// Parse response
var result struct {
StorageQuota struct {
Limit string `json:"limit"` // Can be string or number
Usage string `json:"usage"`
} `json:"storageQuota"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// Parse limit and usage (handle both string and number formats)
var limit, usage int64
if result.StorageQuota.Limit != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
limit = val
}
}
if result.StorageQuota.Usage != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
usage = val
}
}
return &DriveStorageInfo{
Limit: limit,
Usage: usage,
}, nil
}

View File

@@ -0,0 +1,18 @@
package geminicli
import "testing"
func TestDriveStorageInfo(t *testing.T) {
// 测试 DriveStorageInfo 结构体
info := &DriveStorageInfo{
Limit: 100 * 1024 * 1024 * 1024, // 100GB
Usage: 50 * 1024 * 1024 * 1024, // 50GB
}
if info.Limit != 100*1024*1024*1024 {
t.Errorf("Expected limit 100GB, got %d", info.Limit)
}
if info.Usage != 50*1024*1024*1024 {
t.Errorf("Expected usage 50GB, got %d", info.Usage)
}
}

View File

@@ -124,6 +124,90 @@ func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Acc
return &accounts[0], nil
}
func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
if len(ids) == 0 {
return []*service.Account{}, nil
}
// De-duplicate while preserving order of first occurrence.
uniqueIDs := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return []*service.Account{}, nil
}
entAccounts, err := r.client.Account.
Query().
Where(dbaccount.IDIn(uniqueIDs...)).
WithProxy().
All(ctx)
if err != nil {
return nil, err
}
if len(entAccounts) == 0 {
return []*service.Account{}, nil
}
accountIDs := make([]int64, 0, len(entAccounts))
entByID := make(map[int64]*dbent.Account, len(entAccounts))
for _, acc := range entAccounts {
entByID[acc.ID] = acc
accountIDs = append(accountIDs, acc.ID)
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
}
outByID := make(map[int64]*service.Account, len(entAccounts))
for _, entAcc := range entAccounts {
out := accountEntityToService(entAcc)
if out == nil {
continue
}
// Prefer the preloaded proxy edge when available.
if entAcc.Edges.Proxy != nil {
out.Proxy = proxyEntityToService(entAcc.Edges.Proxy)
}
if groups, ok := groupsByAccount[entAcc.ID]; ok {
out.Groups = groups
}
if groupIDs, ok := groupIDsByAccount[entAcc.ID]; ok {
out.GroupIDs = groupIDs
}
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
outByID[entAcc.ID] = out
}
// Preserve input order (first occurrence), and ignore missing IDs.
out := make([]*service.Account, 0, len(uniqueIDs))
for _, id := range uniqueIDs {
if _, ok := entByID[id]; !ok {
continue
}
if acc, ok := outByID[id]; ok && acc != nil {
out = append(out, acc)
}
}
return out, nil
}
// ExistsByID 检查指定 ID 的账号是否存在。
// 相比 GetByID此方法性能更优因为
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值

View File

@@ -294,7 +294,6 @@ func userEntityToService(u *dbent.User) *service.User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,

View File

@@ -2,7 +2,9 @@ package repository
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
@@ -27,6 +29,8 @@ const (
userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15
@@ -112,33 +116,112 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
return 1
`)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
)
type concurrencyCache struct {
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
waitQueueTTLSeconds int // 等待队列过期时间(秒)
}
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间分钟0 或负数使用默认值 15 分钟
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
// waitQueueTTLSeconds: 等待队列过期时间0 或负数使用 slot TTL
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes
}
if waitQueueTTLSeconds <= 0 {
waitQueueTTLSeconds = slotTTLMinutes * 60
}
return &concurrencyCache{
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
waitQueueTTLSeconds: waitQueueTTLSeconds,
}
}
@@ -155,6 +238,10 @@ func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
func accountWaitKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
@@ -225,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
key := accountWaitKey(accountID)
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
key := accountWaitKey(accountID)
val, err := c.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
return 0, err
}
if errors.Is(err, redis.Nil) {
return 0, nil
}
return val, nil
}
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
if err != nil {
return nil, err
}
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
}
}
return loadMap, nil
}
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
key := accountSlotKey(accountID)
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}

View File

@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) {
_ = rdb.Close()
}()
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache)
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
ctx := context.Background()
for _, size := range []int{10, 100, 1000} {

View File

@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes)
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
@@ -218,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
@@ -232,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}

View File

@@ -40,7 +40,6 @@ func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *servic
SetBalance(u.Balance).
SetConcurrency(u.Concurrency).
SetUsername(u.Username).
SetWechat(u.Wechat).
SetNotes(u.Notes)
if !u.CreatedAt.IsZero() {
create.SetCreatedAt(u.CreatedAt)

View File

@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
return fmt.Errorf(
"migration %s checksum mismatch (db=%s file=%s)\n"+
"This means the migration file was modified after being applied to the database.\n"+
"Solutions:\n"+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
name, existing, checksum, name, name,
)
}
continue // 迁移已应用且校验和匹配,跳过
}

View File

@@ -23,7 +23,6 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// users: columns required by repository queries
requireColumn(t, tx, "users", "username", "character varying", 100, false)
requireColumn(t, tx, "users", "wechat", "character varying", 100, false)
requireColumn(t, tx, "users", "notes", "text", 0, false)
// accounts: schedulable and rate-limit fields

View File

@@ -0,0 +1,385 @@
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// UserAttributeDefinitionRepository implementation
type userAttributeDefinitionRepository struct {
client *dbent.Client
}
// NewUserAttributeDefinitionRepository creates a new repository instance
func NewUserAttributeDefinitionRepository(client *dbent.Client) service.UserAttributeDefinitionRepository {
return &userAttributeDefinitionRepository{client: client}
}
func (r *userAttributeDefinitionRepository) Create(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
created, err := client.UserAttributeDefinition.Create().
SetKey(def.Key).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrAttributeKeyExists)
}
def.ID = created.ID
def.DisplayOrder = created.DisplayOrder
def.CreatedAt = created.CreatedAt
def.UpdatedAt = created.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) GetByID(ctx context.Context, id int64) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) GetByKey(ctx context.Context, key string) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) Update(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
updated, err := client.UserAttributeDefinition.UpdateOneID(def.ID).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetDisplayOrder(def.DisplayOrder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, service.ErrAttributeKeyExists)
}
def.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeDefinition.Delete().
Where(userattributedefinition.IDEQ(id)).
Exec(ctx)
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
func (r *userAttributeDefinitionRepository) List(ctx context.Context, enabledOnly bool) ([]service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
q := client.UserAttributeDefinition.Query()
if enabledOnly {
q = q.Where(userattributedefinition.EnabledEQ(true))
}
entities, err := q.Order(dbent.Asc(userattributedefinition.FieldDisplayOrder)).All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeDefinition, 0, len(entities))
for _, e := range entities {
result = append(result, *defEntityToService(e))
}
return result, nil
}
func (r *userAttributeDefinitionRepository) UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error {
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for id, order := range orders {
if _, err := tx.UserAttributeDefinition.UpdateOneID(id).
SetDisplayOrder(order).
Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
}
return tx.Commit()
}
func (r *userAttributeDefinitionRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
client := clientFromContext(ctx, r.client)
return client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Exist(ctx)
}
// UserAttributeValueRepository implementation
type userAttributeValueRepository struct {
client *dbent.Client
}
// NewUserAttributeValueRepository creates a new repository instance
func NewUserAttributeValueRepository(client *dbent.Client) service.UserAttributeValueRepository {
return &userAttributeValueRepository{client: client}
}
func (r *userAttributeValueRepository) GetByUserID(ctx context.Context, userID int64) ([]service.UserAttributeValue, error) {
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDEQ(userID)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) GetByUserIDs(ctx context.Context, userIDs []int64) ([]service.UserAttributeValue, error) {
if len(userIDs) == 0 {
return []service.UserAttributeValue{}, nil
}
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDIn(userIDs...)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) UpsertBatch(ctx context.Context, userID int64, inputs []service.UpdateUserAttributeInput) error {
if len(inputs) == 0 {
return nil
}
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, input := range inputs {
// Use upsert (ON CONFLICT DO UPDATE)
err := tx.UserAttributeValue.Create().
SetUserID(userID).
SetAttributeID(input.AttributeID).
SetValue(input.Value).
OnConflictColumns(userattributevalue.FieldUserID, userattributevalue.FieldAttributeID).
UpdateValue().
UpdateUpdatedAt().
Exec(ctx)
if err != nil {
return err
}
}
return tx.Commit()
}
func (r *userAttributeValueRepository) DeleteByAttributeID(ctx context.Context, attributeID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.AttributeIDEQ(attributeID)).
Exec(ctx)
return err
}
func (r *userAttributeValueRepository) DeleteByUserID(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.UserIDEQ(userID)).
Exec(ctx)
return err
}
// Helper functions for entity to service conversion
func defEntityToService(e *dbent.UserAttributeDefinition) *service.UserAttributeDefinition {
if e == nil {
return nil
}
return &service.UserAttributeDefinition{
ID: e.ID,
Key: e.Key,
Name: e.Name,
Description: e.Description,
Type: service.UserAttributeType(e.Type),
Options: toServiceOptions(e.Options),
Required: e.Required,
Validation: toServiceValidation(e.Validation),
Placeholder: e.Placeholder,
DisplayOrder: e.DisplayOrder,
Enabled: e.Enabled,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
}
// Type conversion helpers (map types <-> service types)
func toEntOptions(opts []service.UserAttributeOption) []map[string]any {
if opts == nil {
return []map[string]any{}
}
result := make([]map[string]any, len(opts))
for i, o := range opts {
result[i] = map[string]any{"value": o.Value, "label": o.Label}
}
return result
}
func toServiceOptions(opts []map[string]any) []service.UserAttributeOption {
if opts == nil {
return []service.UserAttributeOption{}
}
result := make([]service.UserAttributeOption, len(opts))
for i, o := range opts {
result[i] = service.UserAttributeOption{
Value: getString(o, "value"),
Label: getString(o, "label"),
}
}
return result
}
func toEntValidation(v service.UserAttributeValidation) map[string]any {
result := map[string]any{}
if v.MinLength != nil {
result["min_length"] = *v.MinLength
}
if v.MaxLength != nil {
result["max_length"] = *v.MaxLength
}
if v.Min != nil {
result["min"] = *v.Min
}
if v.Max != nil {
result["max"] = *v.Max
}
if v.Pattern != nil {
result["pattern"] = *v.Pattern
}
if v.Message != nil {
result["message"] = *v.Message
}
return result
}
func toServiceValidation(v map[string]any) service.UserAttributeValidation {
result := service.UserAttributeValidation{}
if val := getInt(v, "min_length"); val != nil {
result.MinLength = val
}
if val := getInt(v, "max_length"); val != nil {
result.MaxLength = val
}
if val := getInt(v, "min"); val != nil {
result.Min = val
}
if val := getInt(v, "max"); val != nil {
result.Max = val
}
if val := getStringPtr(v, "pattern"); val != nil {
result.Pattern = val
}
if val := getStringPtr(v, "message"); val != nil {
result.Message = val
}
return result
}
// Helper functions for type conversion
func getString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func getStringPtr(m map[string]any, key string) *string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return &s
}
}
return nil
}
func getInt(m map[string]any, key string) *int {
if v, ok := m[key]; ok {
switch n := v.(type) {
case int:
return &n
case int64:
i := int(n)
return &i
case float64:
i := int(n)
return &i
}
}
return nil
}

View File

@@ -9,6 +9,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -50,7 +51,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
created, err := txClient.User.Create().
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
@@ -133,7 +133,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
updated, err := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
@@ -171,28 +170,38 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error {
}
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
return r.ListWithFilters(ctx, params, service.UserListFilters{})
}
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
q := r.client.User.Query()
if status != "" {
q = q.Where(dbuser.StatusEQ(status))
if filters.Status != "" {
q = q.Where(dbuser.StatusEQ(filters.Status))
}
if role != "" {
q = q.Where(dbuser.RoleEQ(role))
if filters.Role != "" {
q = q.Where(dbuser.RoleEQ(filters.Role))
}
if search != "" {
if filters.Search != "" {
q = q.Where(
dbuser.Or(
dbuser.EmailContainsFold(search),
dbuser.UsernameContainsFold(search),
dbuser.WechatContainsFold(search),
dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(filters.Search),
),
)
}
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
}
q = q.Where(dbuser.IDIn(allowedUserIDs...))
}
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
@@ -252,6 +261,59 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
if len(attrs) == 0 {
return nil
}
// For each attribute filter, get the set of matching user IDs
// Then intersect all sets to get users matching ALL filters
var resultSet map[int64]struct{}
first := true
for attrID, value := range attrs {
// Query user_attribute_values for this attribute
values, err := r.client.UserAttributeValue.Query().
Where(
userattributevalue.AttributeIDEQ(attrID),
userattributevalue.ValueContainsFold(value),
).
All(ctx)
if err != nil {
continue
}
currentSet := make(map[int64]struct{}, len(values))
for _, v := range values {
currentSet[v.UserID] = struct{}{}
}
if first {
resultSet = currentSet
first = false
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
}
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
}
result := make([]int64, 0, len(resultSet))
for userID := range resultSet {
result = append(result, userID)
}
return result
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)

View File

@@ -166,7 +166,7 @@ func (s *UserRepoSuite) TestListWithFilters_Status() {
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(service.StatusActive, users[0].Status)
@@ -176,7 +176,7 @@ func (s *UserRepoSuite) TestListWithFilters_Role() {
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Role: service.RoleAdmin})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(service.RoleAdmin, users[0].Role)
@@ -186,7 +186,7 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "alice"})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Contains(users[0].Email, "alice")
@@ -196,22 +196,12 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "john"})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("JohnDoe", users[0].Username)
}
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
s.mustCreateUser(&service.User{Email: "w1@test.com", Wechat: "wx_hello"})
s.mustCreateUser(&service.User{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("wx_hello", users[0].Wechat)
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
groupActive := s.mustCreateGroup("g-sub-active")
@@ -226,7 +216,7 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "sub@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
@@ -238,7 +228,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
@@ -246,7 +235,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
target := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
@@ -257,7 +245,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
Status: service.StatusDisabled,
})
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@")
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
@@ -448,7 +436,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
@@ -456,7 +443,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user2 := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
@@ -501,7 +487,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")

View File

@@ -15,7 +15,14 @@ import (
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化TTL 可配置,支持长时间运行的 LLM 请求场景
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes)
waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds())
if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout {
waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds())
}
if waitTTLSeconds <= 0 {
waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60
}
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
}
// ProviderSet is the Wire provider set for all repositories
@@ -29,6 +36,8 @@ var ProviderSet = wire.NewSet(
NewUsageLogRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository,
// Cache implementations
NewGatewayCache,

View File

@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1,
"email": "alice@example.com",
"username": "alice",
"wechat": "wx_alice",
"notes": "hello",
"role": "user",
"balance": 12.5,
@@ -348,7 +347,6 @@ func newContractDeps(t *testing.T) *contractDeps {
ID: 1,
Email: "alice@example.com",
Username: "alice",
Wechat: "wx_alice",
Notes: "hello",
Role: service.RoleUser,
Balance: 12.5,
@@ -503,7 +501,7 @@ func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationPar
return nil, nil, errors.New("not implemented")
}
func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}

View File

@@ -54,6 +54,9 @@ func RegisterAdminRoutes(
// 使用记录管理
registerUsageRoutes(admin, h)
// 用户属性管理
registerUserAttributeRoutes(admin, h)
}
}
@@ -82,6 +85,10 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
users.PUT("/:id/attributes", h.Admin.UserAttribute.UpdateUserAttributes)
}
}
@@ -110,6 +117,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
@@ -119,6 +127,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes
@@ -242,3 +251,15 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs := admin.Group("/user-attributes")
{
attrs.GET("", h.Admin.UserAttribute.ListDefinitions)
attrs.POST("", h.Admin.UserAttribute.CreateDefinition)
attrs.POST("/batch", h.Admin.UserAttribute.GetBatchUserAttributes)
attrs.PUT("/reorder", h.Admin.UserAttribute.ReorderDefinitions)
attrs.PUT("/:id", h.Admin.UserAttribute.UpdateDefinition)
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}

View File

@@ -3,6 +3,7 @@ package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return ""
}
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
return "code_assist"
}
return oauthType
}
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
}
func (a *Account) IsGeminiCodeAssist() bool {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return false
}
oauthType := a.GeminiOAuthType()
if oauthType == "" {
return strings.TrimSpace(a.GetCredential("project_id")) != ""
}
return oauthType == "code_assist"
}
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}

View File

@@ -17,6 +17,9 @@ var (
type AccountRepository interface {
Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error)
// GetByIDs fetches accounts by IDs in a single query.
// It should return all accounts found (missing IDs are ignored).
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
ExistsByID(ctx context.Context, id int64) (bool, error)
// GetByCRSAccountID finds an account previously synced from CRS.

View File

@@ -40,6 +40,10 @@ func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, erro
panic("unexpected GetByID call")
}
func (s *accountRepoStub) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
panic("unexpected GetByIDs call")
}
// ExistsByID 返回预设的存在性检查结果。
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) {

View File

@@ -93,10 +93,12 @@ type UsageProgress struct {
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
}
// ClaudeUsageResponse Anthropic API返回的usage结构
@@ -122,17 +124,19 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
return &AccountUsageService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
}
}
@@ -146,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("get account failed: %w", err)
}
if account.Platform == PlatformGemini {
return s.getGeminiUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage有profile scope
if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse
@@ -192,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
}
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{
UpdatedAt: &now,
}
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
return usage, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return usage, nil
}
start := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
return usage, nil
}
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
@@ -388,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据
return info
}
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
if limit <= 0 {
return nil
}
utilization := (float64(used) / float64(limit)) * 100
remainingSeconds := int(resetAt.Sub(now).Seconds())
if remainingSeconds < 0 {
remainingSeconds = 0
}
resetCopy := resetAt
return &UsageProgress{
Utilization: utilization,
ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds,
WindowStats: &WindowStats{
Requests: used,
Tokens: tokens,
Cost: cost,
},
}
}

View File

@@ -13,7 +13,7 @@ import (
// AdminService interface defines admin management operations
type AdminService interface {
// User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error)
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
@@ -35,6 +35,7 @@ type AdminService interface {
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error
@@ -69,7 +70,6 @@ type CreateUserInput struct {
Email string
Password string
Username string
Wechat string
Notes string
Balance float64
Concurrency int
@@ -80,7 +80,6 @@ type UpdateUserInput struct {
Email string
Password string
Username *string
Wechat *string
Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
@@ -251,9 +250,9 @@ func NewAdminService(
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) {
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, 0, err
}
@@ -268,7 +267,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
user := &User{
Email: input.Email,
Username: input.Username,
Wechat: input.Wechat,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
@@ -310,9 +308,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if input.Username != nil {
user.Username = *input.Username
}
if input.Wechat != nil {
user.Wechat = *input.Wechat
}
if input.Notes != nil {
user.Notes = *input.Notes
}
@@ -611,6 +606,19 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
return s.accountRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
if len(ids) == 0 {
return []*Account{}, nil
}
accounts, err := s.accountRepo.GetByIDs(ctx, ids)
if err != nil {
return nil, fmt.Errorf("failed to get accounts by IDs: %w", err)
}
return accounts, nil
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
account := &Account{
Name: input.Name,

View File

@@ -18,7 +18,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
Email: "user@test.com",
Password: "strong-pass",
Username: "tester",
Wechat: "wx",
Notes: "note",
Balance: 12.5,
Concurrency: 7,
@@ -31,7 +30,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
require.Equal(t, int64(10), user.ID)
require.Equal(t, input.Email, user.Email)
require.Equal(t, input.Username, user.Username)
require.Equal(t, input.Wechat, user.Wechat)
require.Equal(t, input.Notes, user.Notes)
require.Equal(t, input.Balance, user.Balance)
require.Equal(t, input.Concurrency, user.Concurrency)

View File

@@ -66,7 +66,7 @@ func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationPar
panic("unexpected List call")
}
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error) {
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}

View File

@@ -25,7 +25,7 @@ const (
antigravityRetryMaxDelay = 16 * time.Second
)
// Antigravity 直接支持的模型
// Antigravity 直接支持的模型(精确匹配透传)
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
"gemini-3-pro-preview": true,
"gemini-3-pro-image": true,
}
// Antigravity 系统默认模型映射表(不支持 → 支持
var antigravityModelMapping = map[string]string{
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
"claude-opus-4": "claude-opus-4-5-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
"claude-haiku-4": "gemini-3-flash",
"claude-haiku-4-5": "gemini-3-flash",
"claude-3-haiku-20240307": "gemini-3-flash",
"claude-haiku-4-5-20251001": "gemini-3-flash",
// 生图模型:官方名 → Antigravity 内部名
"gemini-3-pro-image-preview": "gemini-3-pro-image",
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
var antigravityPrefixMapping = []struct {
prefix string
target string
}{
// 长前缀优先
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
{"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "gemini-3-flash"},
{"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
}
// getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 优先使用账户级映射(复用现有方法
// 1. 账户级映射(用户自定义优先
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped
}
// 2. 系统默认映射
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
return mapped
}
// 3. Gemini 模型透传
if strings.HasPrefix(requestedModel, "gemini-") {
// 2. 直接支持的模型透传
if antigravitySupportedModels[requestedModel] {
return requestedModel
}
// 4. Claude 前缀透传直接支持的模型
if antigravitySupportedModels[requestedModel] {
// 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview
for _, pm := range antigravityPrefixMapping {
if strings.HasPrefix(requestedModel, pm.prefix) {
return pm.target
}
}
// 4. Gemini 模型透传(未匹配到前缀的 gemini 模型)
if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel
}
@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
}
// IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// TestConnectionResult 测试连接结果
@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action
action := "generateContent"
if claudeReq.Stream {

View File

@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
expected: "claude-sonnet-4-5",
},
// 3. Gemini 透传

View File

@@ -18,6 +18,11 @@ type ConcurrencyCache interface {
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
// 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
@@ -27,6 +32,12 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
}
// generateRequestID generates a unique request ID for concurrency slot tracking
@@ -61,6 +72,18 @@ type AcquireResult struct {
ReleaseFunc func() // Must be called when done (typically via defer)
}
type AccountWithConcurrency struct {
ID int64
MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {
return true, nil
}
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
if err != nil {
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
return true, nil
}
return result, nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
if s.cache == nil {
return
}
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
}
}
// GetAccountWaitingCount gets current wait queue count for an account.
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetAccountWaitingCount(ctx, accountID)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int {
@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int {
return userConcurrency + defaultExtraWaitSlots
}
// GetAccountsLoadBatch returns load info for multiple accounts.
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
return nil
}
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
return
}
runCleanup := func() {
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
accounts, err := accountRepo.ListSchedulable(listCtx)
cancel()
if err != nil {
log.Printf("Warning: list schedulable accounts failed: %v", err)
return
}
for _, account := range accounts {
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
accountCancel()
if err != nil {
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
}
}
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
runCleanup()
for range ticker.C {
runCleanup()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {

View File

@@ -91,6 +91,9 @@ const (
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
)
// Admin API Key prefix (distinct from user "sk-" keys)

View File

@@ -32,6 +32,16 @@ func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Ac
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForPlatform) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
var result []*Account
for _, id := range ids {
if acc, ok := m.accountsByID[id]; ok {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForPlatform) ExistsByID(ctx context.Context, id int64) (bool, error) {
if m.accountsByID == nil {
return false, nil
@@ -261,6 +271,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
}
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
@@ -576,6 +614,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
ctx := context.Background()
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
})
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
@@ -783,3 +847,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
})
}
}
// mockConcurrencyService for testing
type mockConcurrencyService struct {
accountLoads map[int64]*AccountLoadInfo
accountWaitCounts map[int64]int
acquireResults map[int64]bool
}
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if m.accountLoads == nil {
return map[int64]*AccountLoadInfo{}, nil
}
result := make(map[int64]*AccountLoadInfo)
for _, acc := range accounts {
if load, ok := m.accountLoads[acc.ID]; ok {
result[acc.ID] = load
} else {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
}
return result, nil
}
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if m.accountWaitCounts == nil {
return 0, nil
}
return m.accountWaitCounts[accountID], nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
}

View File

@@ -13,12 +13,14 @@ import (
"log"
"net/http"
"regexp"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
@@ -66,6 +68,20 @@ type GatewayCache interface {
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
Timeout time.Duration
MaxWaiting int
}
type AccountSelectionResult struct {
Account *Account
Acquired bool
ReleaseFunc func()
WaitPlan *AccountWaitPlan // nil means no wait allowed
}
// ClaudeUsage 表示Claude API返回的usage信息
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
@@ -108,6 +124,7 @@ type GatewayService struct {
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
}
// NewGatewayService creates a new GatewayService
@@ -119,6 +136,7 @@ func NewGatewayService(
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
@@ -134,6 +152,7 @@ func NewGatewayService(
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
@@ -183,6 +202,14 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return ""
}
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 {
return nil
}
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
@@ -332,8 +359,354 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
if err != nil {
return nil, err
}
preferOAuth := platform == PlatformGemini
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
}
_, excluded := excludedIDs[accountID]
return excluded
}
// ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
if isExcluded(acc.ID) {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
candidates = append(candidates, acc)
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
}
}
// ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
for _, acc := range candidates {
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true
}
}
return nil, false
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
}
return config.GatewaySchedulingConfig{
StickySessionMaxWaiting: 3,
StickySessionWaitTimeout: 45 * time.Second,
FallbackWaitTimeout: 30 * time.Second,
FallbackMaxWaiting: 100,
LoadBatchEnabled: true,
SlotCleanupInterval: 30 * time.Second,
}
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil
}
if groupID != nil {
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return "", false, fmt.Errorf("get group failed: %w", err)
}
return group.Platform, false, nil
}
return PlatformAnthropic, false, nil
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
platforms := []string{platform, PlatformAntigravity}
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
filtered = append(filtered, acc)
}
return filtered, useMixed, nil
}
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
if err == nil && len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
return nil, useMixed, err
}
return accounts, useMixed, nil
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
}
if useMixed {
if account.Platform == platform {
return true
}
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
}
return account.Platform == platform
}
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
}
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
switch {
case a.LastUsedAt == nil && b.LastUsedAt != nil:
return true
case a.LastUsedAt != nil && b.LastUsedAt == nil:
return false
case a.LastUsedAt == nil && b.LastUsedAt == nil:
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
default:
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
@@ -389,7 +762,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
@@ -419,6 +794,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
platforms := []string{nativePlatform, PlatformAntigravity}
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
@@ -478,7 +854,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
@@ -515,24 +893,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func IsAntigravityModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// GetAccessToken 获取账号凭证
@@ -684,6 +1048,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
}
@@ -786,6 +1174,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
@@ -838,6 +1233,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
return claude.DefaultBetaHeader
}
func requestNeedsBetaFeatures(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 缺少/错误的 beta header换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if strings.Contains(msg, "anthropic-beta") ||
strings.Contains(msg, "beta feature") ||
strings.Contains(msg, "requires beta") {
return true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
return true
}
return false
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
inner := strings.TrimSpace(m)
// 有些上游会把完整 JSON 作为字符串塞进 message
if strings.HasPrefix(inner, "{") {
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
return m
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
@@ -850,6 +1322,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
switch resp.StatusCode {
case 400:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream 400 error (account=%d platform=%s type=%s): %s",
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401:
@@ -1329,6 +1811,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应
errMsg := "Upstream request failed"
switch resp.StatusCode {
@@ -1409,6 +1903,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
@@ -1424,3 +1925,58 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
},
})
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
} else {
accounts, err = s.accountRepo.ListSchedulable(ctx)
}
if err != nil || len(accounts) == 0 {
return nil
}
// Filter by platform if specified
if platform != "" {
filtered := make([]Account, 0)
for _, acc := range accounts {
if acc.Platform == platform {
filtered = append(filtered, acc)
}
}
accounts = filtered
}
// Collect unique models from all accounts
modelSet := make(map[string]struct{})
hasAnyMapping := false
for _, acc := range accounts {
mapping := acc.GetModelMapping()
if len(mapping) > 0 {
hasAnyMapping = true
for model := range mapping {
modelSet[model] = struct{}{}
}
}
}
// If no account has model_mapping, return nil (use default)
if !hasAnyMapping {
return nil
}
// Convert to slice
models := make([]string, 0, len(modelSet))
for model := range modelSet {
models = append(models, model)
}
return models
}

View File

@@ -116,8 +116,20 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
valid = true
}
if valid {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
@@ -157,6 +169,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil {
selected = acc
continue
@@ -1886,13 +1907,44 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
if statusCode != 429 {
return
}
oauthType := account.GeminiOAuthType()
tierID := account.GeminiTierID()
projectID := strings.TrimSpace(account.GetCredential("project_id"))
isCodeAssist := account.IsGeminiCodeAssist()
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
ra := time.Now().Add(5 * time.Minute)
// 根据账号类型使用不同的默认重置时间
var ra time.Time
if isCodeAssist {
// Code Assist: fallback cooldown by tier
cooldown := geminiCooldownForTier(tierID)
if s.rateLimitService != nil {
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
}
ra = time.Now().Add(cooldown)
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
} else {
// API Key / AI Studio OAuth: PST 午夜
if ts := nextGeminiDailyResetUnix(); ts != nil {
ra = time.Unix(*ts, 0)
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
} else {
// 兜底5 分钟
ra = time.Now().Add(5 * time.Minute)
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
}
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
return
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
// 使用解析到的重置时间
resetTime := time.Unix(*resetAt, 0)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
account.ID, resetTime, oauthType, tierID)
}
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
@@ -1948,16 +2000,7 @@ func looksLikeGeminiDailyQuota(message string) bool {
}
func nextGeminiDailyResetUnix() *int64 {
loc, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
// Fallback: PST without DST.
loc = time.FixedZone("PST", -8*3600)
}
now := time.Now().In(loc)
reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc)
if !reset.After(now) {
reset = reset.Add(24 * time.Hour)
}
reset := geminiDailyResetTime(time.Now())
ts := reset.Unix()
return &ts
}
@@ -2278,11 +2321,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
"properties": map[string]any{},
}
}
// 清理 JSON Schema
cleanedParams := cleanToolSchema(params)
funcDecls = append(funcDecls, map[string]any{
"name": name,
"description": desc,
"parameters": params,
"parameters": cleanedParams,
})
}
@@ -2296,6 +2341,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
}
}
// cleanToolSchema 清理工具的 JSON Schema移除 Gemini 不支持的字段
func cleanToolSchema(schema any) any {
if schema == nil {
return nil
}
switch v := schema.(type) {
case map[string]any:
cleaned := make(map[string]any)
for key, value := range v {
// 跳过不支持的字段
if key == "$schema" || key == "$id" || key == "$ref" ||
key == "additionalProperties" || key == "minLength" ||
key == "maxLength" || key == "minItems" || key == "maxItems" {
continue
}
// 递归清理嵌套对象
cleaned[key] = cleanToolSchema(value)
}
// 规范化 type 字段为大写
if typeVal, ok := cleaned["type"].(string); ok {
cleaned["type"] = strings.ToUpper(typeVal)
}
return cleaned
case []any:
cleaned := make([]any, len(v))
for i, item := range v {
cleaned[i] = cleanToolSchema(item)
}
return cleaned
default:
return v
}
}
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
out := make(map[string]any)
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {

View File

@@ -0,0 +1,128 @@
package service
import (
"testing"
)
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct {
name string
tools any
expectedLen int
description string
}{
{
name: "Standard tools",
tools: []any{
map[string]any{
"name": "get_weather",
"description": "Get weather info",
"input_schema": map[string]any{"type": "object"},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []any{
map[string]any{
"type": "custom",
"name": "mcp_tool",
"custom": map[string]any{
"description": "MCP tool description",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从custom字段读取",
},
{
name: "Mixed standard and custom tools",
tools: []any{
map[string]any{
"name": "standard_tool",
"description": "Standard",
"input_schema": map[string]any{"type": "object"},
},
map[string]any{
"type": "custom",
"name": "custom_tool",
"custom": map[string]any{
"description": "Custom",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "混合工具应该都能正确转换",
},
{
name: "Custom tool without custom field",
tools: []any{
map[string]any{
"type": "custom",
"name": "invalid_custom",
// 缺少 custom 字段
},
},
expectedLen: 0, // 应该被跳过
description: "缺少custom字段的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertClaudeToolsToGeminiTools(tt.tools)
if tt.expectedLen == 0 {
if result != nil {
t.Errorf("%s: expected nil result, got %v", tt.description, result)
}
return
}
if result == nil {
t.Fatalf("%s: expected non-nil result", tt.description)
}
if len(result) != 1 {
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
return
}
toolDecl, ok := result[0].(map[string]any)
if !ok {
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
}
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
if !ok {
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
}
toolsArr, _ := tt.tools.([]any)
expectedFuncCount := 0
for _, tool := range toolsArr {
toolMap, _ := tool.(map[string]any)
if toolMap["name"] != "" {
// 检查是否为有效的custom工具
if toolMap["type"] == "custom" {
if toolMap["custom"] != nil {
expectedFuncCount++
}
} else {
expectedFuncCount++
}
}
}
if len(funcDecls) != expectedFuncCount {
t.Errorf("%s: expected %d function declarations, got %d",
tt.description, expectedFuncCount, len(funcDecls))
}
})
}
}

View File

@@ -25,6 +25,16 @@ func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Acco
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForGemini) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
var result []*Account
for _, id := range ids {
if acc, ok := m.accountsByID[id]; ok {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) {
if m.accountsByID == nil {
return false, nil

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"time"
@@ -16,6 +17,26 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
const (
TierAIPremium = "AI_PREMIUM"
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
TierFree = "FREE"
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
)
const (
GB = 1024 * 1024 * 1024
TB = 1024 * GB
StorageTierUnlimited = 100 * TB // 100TB
StorageTierAIPremium = 2 * TB // 2TB
StorageTierStandard = 200 * GB // 200GB
StorageTierBasic = 100 * GB // 100GB
StorageTierFree = 15 * GB // 15GB
)
type GeminiOAuthService struct {
sessionStore *geminicli.SessionStore
proxyRepo ProxyRepository
@@ -88,13 +109,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
// - ai_studio: requires a user-provided OAuth client.
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
oauthCfg.ClientID = ""
oauthCfg.ClientSecret = ""
}
@@ -155,14 +177,152 @@ type GeminiExchangeCodeInput struct {
}
type GeminiTokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
}
// validateTierID validates tier_id format and length
func validateTierID(tierID string) error {
if tierID == "" {
return nil // Empty is allowed
}
if len(tierID) > 64 {
return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
}
// Allow alphanumeric, underscore, hyphen, and slash (for tier paths)
if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) {
return fmt.Errorf("tier_id contains invalid characters")
}
return nil
}
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
// Prioritizes IsDefault tier, falls back to first non-empty tier
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
tierID := "LEGACY"
// First pass: look for default tier
for _, tier := range allowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
// Second pass: if still LEGACY, take first non-empty tier
if tierID == "LEGACY" {
for _, tier := range allowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
return tierID
}
// inferGoogleOneTier infers Google One tier from Drive storage limit
func inferGoogleOneTier(storageBytes int64) string {
if storageBytes <= 0 {
return TierGoogleOneUnknown
}
if storageBytes > StorageTierUnlimited {
return TierGoogleOneUnlimited
}
if storageBytes >= StorageTierAIPremium {
return TierAIPremium
}
if storageBytes >= StorageTierStandard {
return TierGoogleOneStandard
}
if storageBytes >= StorageTierBasic {
return TierGoogleOneBasic
}
if storageBytes >= StorageTierFree {
return TierFree
}
return TierGoogleOneUnknown
}
// fetchGoogleOneTier fetches Google One tier from Drive API
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
driveClient := geminicli.NewDriveClient()
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
if err != nil {
// Check if it's a 403 (scope not granted)
if strings.Contains(err.Error(), "status 403") {
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
return TierGoogleOneUnknown, nil, err
}
// Other errors
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
return TierGoogleOneUnknown, nil, err
}
tierID := inferGoogleOneTier(storageInfo.Limit)
return tierID, storageInfo, nil
}
// RefreshAccountGoogleOneTier 刷新单个账号的 Google One Tier
func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
ctx context.Context,
account *Account,
) (tierID string, extra map[string]any, credentials map[string]any, err error) {
if account == nil {
return "", nil, nil, fmt.Errorf("account is nil")
}
// 验证账号类型
oauthType, ok := account.Credentials["oauth_type"].(string)
if !ok || oauthType != "google_one" {
return "", nil, nil, fmt.Errorf("not a google_one OAuth account")
}
// 获取 access_token
accessToken, ok := account.Credentials["access_token"].(string)
if !ok || accessToken == "" {
return "", nil, nil, fmt.Errorf("missing access_token")
}
// 获取 proxy URL
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 调用 Drive API
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, accessToken, proxyURL)
if err != nil {
return "", nil, nil, err
}
// 构建 extra 数据(保留原有 extra 字段)
extra = make(map[string]any)
for k, v := range account.Extra {
extra[k] = v
}
if storageInfo != nil {
extra["drive_storage_limit"] = storageInfo.Limit
extra["drive_storage_usage"] = storageInfo.Usage
extra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339)
}
// 构建 credentials 数据
credentials = make(map[string]any)
for k, v := range account.Credentials {
credentials[k] = v
}
credentials["tier_id"] = tierID
return tierID, extra, credentials, nil
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
@@ -219,26 +379,78 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID)
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
// 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
projectID := sessionProjectID
var tierID string
// 对于 code_assist 模式project_id 是必需的
// 对于 code_assist 模式project_id 是必需的,需要调用 Code Assist API
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id配额由 Google 网关自动识别
// 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API
if oauthType == "code_assist" {
switch oauthType {
case "code_assist":
if projectID == "" {
var err error
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
}
} else {
// 用户手动填了 project_id仍需调用 LoadCodeAssist 获取 tierID
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
} else {
tierID = fetchedTierID
}
}
if strings.TrimSpace(projectID) == "" {
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
// tierID 缺失时使用默认值
if tierID == "" {
tierID = "LEGACY"
}
case "google_one":
// Attempt to fetch Drive storage tier
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// Log warning but don't block - use fallback
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
tierID = TierGoogleOneUnknown
}
// Store Drive info in extra field for caching
if storageInfo != nil {
tokenInfo := &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType,
Extra: map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
},
}
return tokenInfo, nil
}
}
// ai_studio 模式不设置 tierID保持为空
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
@@ -248,6 +460,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType,
}, nil
}
@@ -266,8 +479,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
if err == nil {
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
// 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
@@ -354,18 +574,75 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
tokenInfo.ProjectID = existingProjectID
}
// 尝试从账号凭证获取 tierID向后兼容
existingTierID := strings.TrimSpace(account.GetCredential("tier_id"))
// For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
switch oauthType {
case "code_assist":
// 先设置默认值或保留旧值,确保 tier_id 始终有值
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = "LEGACY" // 默认值
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
// 尝试自动探测 project_id 和 tier_id
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
if needDetect {
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err)
} else {
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
tokenInfo.ProjectID = projectID
}
// 只有当原来没有 tier_id 且探测成功时才更新
if existingTierID == "" && tierID != "" {
tokenInfo.TierID = tierID
}
}
}
if strings.TrimSpace(tokenInfo.ProjectID) == "" {
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
}
tokenInfo.ProjectID = projectID
case "google_one":
// Check if tier cache is stale (> 24 hours)
needsRefresh := true
if account.Extra != nil {
if updatedAtStr, ok := account.Extra["drive_tier_updated_at"].(string); ok {
if updatedAt, err := time.Parse(time.RFC3339, updatedAtStr); err == nil {
if time.Since(updatedAt) <= 24*time.Hour {
needsRefresh = false
// Use cached tier
if existingTierID != "" {
tokenInfo.TierID = existingTierID
}
}
}
}
}
if needsRefresh {
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
if err == nil && storageInfo != nil {
tokenInfo.TierID = tierID
tokenInfo.Extra = map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
}
} else {
// Fallback to cached or unknown
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = TierGoogleOneUnknown
}
}
}
}
return tokenInfo, nil
@@ -388,9 +665,22 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID
}
if tokenInfo.TierID != "" {
// Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID
}
// Silently skip invalid tier_id (don't block account creation)
}
if tokenInfo.OAuthType != "" {
creds["oauth_type"] = tokenInfo.OAuthType
}
// Store extra metadata (Drive info) if present
if len(tokenInfo.Extra) > 0 {
for k, v := range tokenInfo.Extra {
creds[k] = v
}
}
return creds
}
@@ -398,33 +688,22 @@ func (s *GeminiOAuthService) Stop() {
s.sessionStore.Stop()
}
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
if s.codeAssist == nil {
return "", errors.New("code assist client not configured")
return "", "", errors.New("code assist client not configured")
}
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
}
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY"
if loadResp != nil {
for _, tier := range loadResp.AllowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
for _, tier := range loadResp.AllowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
}
// If LoadCodeAssist returned a project, use it
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
}
req := &geminicli.OnboardUserRequest{
@@ -443,39 +722,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
return strings.TrimSpace(fallback), tierID, nil
}
return "", err
return "", tierID, err
}
if resp.Done {
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
switch v := resp.Response.CloudAICompanionProject.(type) {
case string:
return strings.TrimSpace(v), nil
return strings.TrimSpace(v), tierID, nil
case map[string]any:
if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil
return strings.TrimSpace(id), tierID, nil
}
}
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
return strings.TrimSpace(fallback), tierID, nil
}
return "", errors.New("onboardUser completed but no project_id returned")
return "", tierID, errors.New("onboardUser completed but no project_id returned")
}
time.Sleep(2 * time.Second)
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
return strings.TrimSpace(fallback), tierID, nil
}
if loadErr != nil {
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
}
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}
type googleCloudProject struct {

View File

@@ -0,0 +1,51 @@
package service
import "testing"
func TestInferGoogleOneTier(t *testing.T) {
tests := []struct {
name string
storageBytes int64
expectedTier string
}{
{"Negative storage", -1, TierGoogleOneUnknown},
{"Zero storage", 0, TierGoogleOneUnknown},
// Free tier boundary (15GB)
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
{"Free tier (15GB)", StorageTierFree, TierFree},
// Basic tier boundary (100GB)
{"Between free and basic", 50 * GB, TierFree},
{"Just below basic tier", StorageTierBasic - 1, TierFree},
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
// Standard tier boundary (200GB)
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
// AI Premium tier boundary (2TB)
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
// Unlimited tier boundary (> 100TB)
{"Between premium and unlimited", 50 * TB, TierAIPremium},
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := inferGoogleOneTier(tt.storageBytes)
if result != tt.expectedTier {
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
tt.storageBytes, result, tt.expectedTier)
}
})
}
}

View File

@@ -0,0 +1,268 @@
package service
import (
"context"
"encoding/json"
"errors"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
type geminiModelClass string
const (
geminiModelPro geminiModelClass = "pro"
geminiModelFlash geminiModelClass = "flash"
)
type GeminiDailyQuota struct {
ProRPD int64
FlashRPD int64
}
type GeminiTierPolicy struct {
Quota GeminiDailyQuota
Cooldown time.Duration
}
type GeminiQuotaPolicy struct {
tiers map[string]GeminiTierPolicy
}
type GeminiUsageTotals struct {
ProRequests int64
FlashRequests int64
ProTokens int64
FlashTokens int64
ProCost float64
FlashCost float64
}
const geminiQuotaCacheTTL = time.Minute
type geminiQuotaOverrides struct {
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
}
type GeminiQuotaService struct {
cfg *config.Config
settingRepo SettingRepository
mu sync.Mutex
cachedAt time.Time
policy *GeminiQuotaPolicy
}
func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService {
return &GeminiQuotaService{
cfg: cfg,
settingRepo: settingRepo,
}
}
func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if s == nil {
return newGeminiQuotaPolicy()
}
now := time.Now()
s.mu.Lock()
if s.policy != nil && now.Sub(s.cachedAt) < geminiQuotaCacheTTL {
policy := s.policy
s.mu.Unlock()
return policy
}
s.mu.Unlock()
policy := newGeminiQuotaPolicy()
if s.cfg != nil {
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
}
}
}
if s.settingRepo != nil {
value, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy)
if err != nil && !errors.Is(err, ErrSettingNotFound) {
log.Printf("gemini quota: load setting failed: %v", err)
} else if strings.TrimSpace(value) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
}
}
}
s.mu.Lock()
s.policy = policy
s.cachedAt = now
s.mu.Unlock()
return policy
}
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
if account == nil || !account.IsGeminiCodeAssist() {
return GeminiDailyQuota{}, false
}
policy := s.Policy(ctx)
return policy.QuotaForTier(account.GeminiTierID())
}
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
policy := s.Policy(ctx)
return policy.CooldownForTier(tierID)
}
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
return &GeminiQuotaPolicy{
tiers: map[string]GeminiTierPolicy{
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
},
}
}
func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) {
if p == nil || len(tiers) == 0 {
return
}
for rawID, override := range tiers {
tierID := normalizeGeminiTierID(rawID)
if tierID == "" {
continue
}
policy, ok := p.tiers[tierID]
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
if override.ProRPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
}
if override.FlashRPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
}
if override.CooldownMinutes != nil {
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
policy.Cooldown = time.Duration(minutes) * time.Minute
}
p.tiers[tierID] = policy
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
policy, ok := p.policyForTier(tierID)
if !ok {
return GeminiDailyQuota{}, false
}
return policy.Quota, true
}
func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration {
policy, ok := p.policyForTier(tierID)
if ok && policy.Cooldown > 0 {
return policy.Cooldown
}
return 5 * time.Minute
}
func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool) {
if p == nil {
return GeminiTierPolicy{}, false
}
normalized := normalizeGeminiTierID(tierID)
if normalized == "" {
normalized = "LEGACY"
}
if policy, ok := p.tiers[normalized]; ok {
return policy, true
}
policy, ok := p.tiers["LEGACY"]
return policy, ok
}
func normalizeGeminiTierID(tierID string) string {
return strings.ToUpper(strings.TrimSpace(tierID))
}
func clampGeminiQuotaInt64(value int64) int64 {
if value < 0 {
return 0
}
return value
}
func clampGeminiQuotaInt(value int) int {
if value < 0 {
return 0
}
return value
}
func geminiCooldownForTier(tierID string) time.Duration {
policy := newGeminiQuotaPolicy()
return policy.CooldownForTier(tierID)
}
func geminiModelClassFromName(model string) geminiModelClass {
name := strings.ToLower(strings.TrimSpace(model))
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
return geminiModelFlash
}
return geminiModelPro
}
func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals {
var totals GeminiUsageTotals
for _, stat := range stats {
switch geminiModelClassFromName(stat.Model) {
case geminiModelFlash:
totals.FlashRequests += stat.Requests
totals.FlashTokens += stat.TotalTokens
totals.FlashCost += stat.ActualCost
default:
totals.ProRequests += stat.Requests
totals.ProTokens += stat.TotalTokens
totals.ProCost += stat.ActualCost
}
}
return totals
}
func geminiQuotaLocation() *time.Location {
loc, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
return time.FixedZone("PST", -8*3600)
}
return loc
}
func geminiDailyWindowStart(now time.Time) time.Time {
loc := geminiQuotaLocation()
localNow := now.In(loc)
return time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
}
func geminiDailyResetTime(now time.Time) time.Time {
loc := geminiQuotaLocation()
localNow := now.In(loc)
start := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
reset := start.Add(24 * time.Hour)
if !reset.After(localNow) {
reset = reset.Add(24 * time.Hour)
}
return reset
}

View File

@@ -112,17 +112,21 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
if err != nil {
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
return accessToken, nil
}
detected = strings.TrimSpace(detected)
tierID = strings.TrimSpace(tierID)
if detected != "" {
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials["project_id"] = detected
if tierID != "" {
account.Credentials["tier_id"] = tierID
}
_ = p.accountRepo.Update(ctx, account)
}
}

View File

@@ -13,6 +13,7 @@ import (
"log"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
"time"
@@ -80,6 +81,7 @@ type OpenAIGatewayService struct {
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
@@ -95,6 +97,7 @@ func NewOpenAIGatewayService(
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
@@ -108,6 +111,7 @@ func NewOpenAIGatewayService(
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
@@ -126,6 +130,14 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
return hex.EncodeToString(hash[:])
}
// BindStickySession sets session -> account binding with standard TTL.
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 {
return nil
}
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL)
}
// SelectAccount selects an OpenAI account with sticky session support
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
@@ -218,6 +230,254 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
return selected, nil
}
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
accounts, err := s.listSchedulableAccounts(ctx, groupID)
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
}
_, excluded := excludedIDs[accountID]
return excluded
}
// ============ Layer 1: Sticky session ============
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
// ============ Layer 2: Load-aware selection ============
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
if isExcluded(acc.ID) {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
candidates = append(candidates, acc)
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, false)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
}
}
// ============ Layer 3: Fallback wait ============
sortAccountsByPriorityAndLastUsed(candidates, false)
for _, acc := range candidates {
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, errors.New("no available accounts")
}
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
return accounts, nil
}
func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
}
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
}
return config.GatewaySchedulingConfig{
StickySessionMaxWaiting: 3,
StickySessionWaitTimeout: 45 * time.Second,
FallbackWaitTimeout: 30 * time.Second,
FallbackMaxWaiting: 100,
LoadBatchEnabled: true,
SlotCleanupInterval: 30 * time.Second,
}
}
// GetAccessToken gets the access token for an OpenAI account
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {

View File

@@ -5,6 +5,8 @@ import (
"log"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -12,15 +14,30 @@ import (
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
accountRepo AccountRepository
cfg *config.Config
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
type geminiUsageCacheEntry struct {
windowStart time.Time
cachedAt time.Time
totals GeminiUsageTotals
}
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *RateLimitService {
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
return &RateLimitService{
accountRepo: accountRepo,
cfg: cfg,
accountRepo: accountRepo,
usageRepo: usageRepo,
cfg: cfg,
geminiQuotaService: geminiQuotaService,
usageCache: make(map[int64]*geminiUsageCacheEntry),
}
}
@@ -62,6 +79,106 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
}
// PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped.
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" {
return true, nil
}
if s.usageRepo == nil || s.geminiQuotaService == nil {
return true, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return true, nil
}
var limit int64
switch geminiModelClassFromName(requestedModel) {
case geminiModelFlash:
limit = quota.FlashRPD
default:
limit = quota.ProRPD
}
if limit <= 0 {
return true, nil
}
now := time.Now()
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return true, err
}
totals = geminiAggregateUsage(stats)
s.setGeminiUsageTotals(account.ID, start, now, totals)
}
var used int64
switch geminiModelClassFromName(requestedModel) {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
if used >= limit {
resetAt := geminiDailyResetTime(now)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
return false, nil
}
return true, nil
}
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
s.usageCacheMu.RLock()
defer s.usageCacheMu.RUnlock()
if s.usageCache == nil {
return GeminiUsageTotals{}, false
}
entry, ok := s.usageCache[accountID]
if !ok || entry == nil {
return GeminiUsageTotals{}, false
}
if !entry.windowStart.Equal(windowStart) {
return GeminiUsageTotals{}, false
}
if now.Sub(entry.cachedAt) >= geminiPrecheckCacheTTL {
return GeminiUsageTotals{}, false
}
return entry.totals, true
}
func (s *RateLimitService) setGeminiUsageTotals(accountID int64, windowStart, now time.Time, totals GeminiUsageTotals) {
s.usageCacheMu.Lock()
defer s.usageCacheMu.Unlock()
if s.usageCache == nil {
s.usageCache = make(map[int64]*geminiUsageCacheEntry)
}
s.usageCache[accountID] = &geminiUsageCacheEntry{
windowStart: windowStart,
cachedAt: now,
totals: totals,
}
}
// GeminiCooldown returns the fallback cooldown duration for Gemini 429s based on tier.
func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) time.Duration {
if account == nil {
return 5 * time.Minute
}
return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID())
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {

View File

@@ -10,7 +10,6 @@ type User struct {
ID int64
Email string
Username string
Wechat string
Notes string
PasswordHash string
Role string

View File

@@ -0,0 +1,125 @@
package service
import (
"context"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// Error definitions for user attribute operations
var (
ErrAttributeDefinitionNotFound = infraerrors.NotFound("ATTRIBUTE_DEFINITION_NOT_FOUND", "attribute definition not found")
ErrAttributeKeyExists = infraerrors.Conflict("ATTRIBUTE_KEY_EXISTS", "attribute key already exists")
ErrInvalidAttributeType = infraerrors.BadRequest("INVALID_ATTRIBUTE_TYPE", "invalid attribute type")
ErrAttributeValidationFailed = infraerrors.BadRequest("ATTRIBUTE_VALIDATION_FAILED", "attribute value validation failed")
)
// UserAttributeType represents supported attribute types
type UserAttributeType string
const (
AttributeTypeText UserAttributeType = "text"
AttributeTypeTextarea UserAttributeType = "textarea"
AttributeTypeNumber UserAttributeType = "number"
AttributeTypeEmail UserAttributeType = "email"
AttributeTypeURL UserAttributeType = "url"
AttributeTypeDate UserAttributeType = "date"
AttributeTypeSelect UserAttributeType = "select"
AttributeTypeMultiSelect UserAttributeType = "multi_select"
)
// UserAttributeOption represents a select option for select/multi_select types
type UserAttributeOption struct {
Value string `json:"value"`
Label string `json:"label"`
}
// UserAttributeValidation represents validation rules for an attribute
type UserAttributeValidation struct {
MinLength *int `json:"min_length,omitempty"`
MaxLength *int `json:"max_length,omitempty"`
Min *int `json:"min,omitempty"`
Max *int `json:"max,omitempty"`
Pattern *string `json:"pattern,omitempty"`
Message *string `json:"message,omitempty"`
}
// UserAttributeDefinition represents a custom attribute definition
type UserAttributeDefinition struct {
ID int64
Key string
Name string
Description string
Type UserAttributeType
Options []UserAttributeOption
Required bool
Validation UserAttributeValidation
Placeholder string
DisplayOrder int
Enabled bool
CreatedAt time.Time
UpdatedAt time.Time
}
// UserAttributeValue represents a user's attribute value
type UserAttributeValue struct {
ID int64
UserID int64
AttributeID int64
Value string
CreatedAt time.Time
UpdatedAt time.Time
}
// CreateAttributeDefinitionInput for creating new definition
type CreateAttributeDefinitionInput struct {
Key string
Name string
Description string
Type UserAttributeType
Options []UserAttributeOption
Required bool
Validation UserAttributeValidation
Placeholder string
Enabled bool
}
// UpdateAttributeDefinitionInput for updating definition
type UpdateAttributeDefinitionInput struct {
Name *string
Description *string
Type *UserAttributeType
Options *[]UserAttributeOption
Required *bool
Validation *UserAttributeValidation
Placeholder *string
Enabled *bool
}
// UpdateUserAttributeInput for updating a single attribute value
type UpdateUserAttributeInput struct {
AttributeID int64
Value string
}
// UserAttributeDefinitionRepository interface for attribute definition persistence
type UserAttributeDefinitionRepository interface {
Create(ctx context.Context, def *UserAttributeDefinition) error
GetByID(ctx context.Context, id int64) (*UserAttributeDefinition, error)
GetByKey(ctx context.Context, key string) (*UserAttributeDefinition, error)
Update(ctx context.Context, def *UserAttributeDefinition) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error)
UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error
ExistsByKey(ctx context.Context, key string) (bool, error)
}
// UserAttributeValueRepository interface for user attribute value persistence
type UserAttributeValueRepository interface {
GetByUserID(ctx context.Context, userID int64) ([]UserAttributeValue, error)
GetByUserIDs(ctx context.Context, userIDs []int64) ([]UserAttributeValue, error)
UpsertBatch(ctx context.Context, userID int64, values []UpdateUserAttributeInput) error
DeleteByAttributeID(ctx context.Context, attributeID int64) error
DeleteByUserID(ctx context.Context, userID int64) error
}

View File

@@ -0,0 +1,295 @@
package service
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// UserAttributeService handles attribute management
type UserAttributeService struct {
defRepo UserAttributeDefinitionRepository
valueRepo UserAttributeValueRepository
}
// NewUserAttributeService creates a new service instance
func NewUserAttributeService(
defRepo UserAttributeDefinitionRepository,
valueRepo UserAttributeValueRepository,
) *UserAttributeService {
return &UserAttributeService{
defRepo: defRepo,
valueRepo: valueRepo,
}
}
// CreateDefinition creates a new attribute definition
func (s *UserAttributeService) CreateDefinition(ctx context.Context, input CreateAttributeDefinitionInput) (*UserAttributeDefinition, error) {
// Validate type
if !isValidAttributeType(input.Type) {
return nil, ErrInvalidAttributeType
}
// Check if key exists
exists, err := s.defRepo.ExistsByKey(ctx, input.Key)
if err != nil {
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
return nil, ErrAttributeKeyExists
}
def := &UserAttributeDefinition{
Key: input.Key,
Name: input.Name,
Description: input.Description,
Type: input.Type,
Options: input.Options,
Required: input.Required,
Validation: input.Validation,
Placeholder: input.Placeholder,
Enabled: input.Enabled,
}
if err := s.defRepo.Create(ctx, def); err != nil {
return nil, fmt.Errorf("create definition: %w", err)
}
return def, nil
}
// GetDefinition retrieves a definition by ID
func (s *UserAttributeService) GetDefinition(ctx context.Context, id int64) (*UserAttributeDefinition, error) {
return s.defRepo.GetByID(ctx, id)
}
// ListDefinitions lists all definitions
func (s *UserAttributeService) ListDefinitions(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error) {
return s.defRepo.List(ctx, enabledOnly)
}
// UpdateDefinition updates an existing definition
func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, input UpdateAttributeDefinitionInput) (*UserAttributeDefinition, error) {
def, err := s.defRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != nil {
def.Name = *input.Name
}
if input.Description != nil {
def.Description = *input.Description
}
if input.Type != nil {
if !isValidAttributeType(*input.Type) {
return nil, ErrInvalidAttributeType
}
def.Type = *input.Type
}
if input.Options != nil {
def.Options = *input.Options
}
if input.Required != nil {
def.Required = *input.Required
}
if input.Validation != nil {
def.Validation = *input.Validation
}
if input.Placeholder != nil {
def.Placeholder = *input.Placeholder
}
if input.Enabled != nil {
def.Enabled = *input.Enabled
}
if err := s.defRepo.Update(ctx, def); err != nil {
return nil, fmt.Errorf("update definition: %w", err)
}
return def, nil
}
// DeleteDefinition soft-deletes a definition and hard-deletes associated values
func (s *UserAttributeService) DeleteDefinition(ctx context.Context, id int64) error {
// Check if definition exists
_, err := s.defRepo.GetByID(ctx, id)
if err != nil {
return err
}
// First delete all values (hard delete)
if err := s.valueRepo.DeleteByAttributeID(ctx, id); err != nil {
return fmt.Errorf("delete values: %w", err)
}
// Then soft-delete the definition
if err := s.defRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete definition: %w", err)
}
return nil
}
// ReorderDefinitions updates display order for multiple definitions
func (s *UserAttributeService) ReorderDefinitions(ctx context.Context, orders map[int64]int) error {
return s.defRepo.UpdateDisplayOrders(ctx, orders)
}
// GetUserAttributes retrieves all attribute values for a user
func (s *UserAttributeService) GetUserAttributes(ctx context.Context, userID int64) ([]UserAttributeValue, error) {
return s.valueRepo.GetByUserID(ctx, userID)
}
// GetBatchUserAttributes retrieves attribute values for multiple users
// Returns a map of userID -> map of attributeID -> value
func (s *UserAttributeService) GetBatchUserAttributes(ctx context.Context, userIDs []int64) (map[int64]map[int64]string, error) {
values, err := s.valueRepo.GetByUserIDs(ctx, userIDs)
if err != nil {
return nil, err
}
result := make(map[int64]map[int64]string)
for _, v := range values {
if result[v.UserID] == nil {
result[v.UserID] = make(map[int64]string)
}
result[v.UserID][v.AttributeID] = v.Value
}
return result, nil
}
// UpdateUserAttributes batch updates attribute values for a user
func (s *UserAttributeService) UpdateUserAttributes(ctx context.Context, userID int64, inputs []UpdateUserAttributeInput) error {
// Validate all values before updating
defs, err := s.defRepo.List(ctx, true)
if err != nil {
return fmt.Errorf("list definitions: %w", err)
}
defMap := make(map[int64]*UserAttributeDefinition, len(defs))
for i := range defs {
defMap[defs[i].ID] = &defs[i]
}
for _, input := range inputs {
def, ok := defMap[input.AttributeID]
if !ok {
return ErrAttributeDefinitionNotFound
}
if err := s.validateValue(def, input.Value); err != nil {
return err
}
}
return s.valueRepo.UpsertBatch(ctx, userID, inputs)
}
// validateValue validates a value against its definition
func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value string) error {
// Skip validation for empty non-required fields
if value == "" && !def.Required {
return nil
}
// Required check
if def.Required && value == "" {
return validationError(fmt.Sprintf("%s is required", def.Name))
}
v := def.Validation
// String length validation
if v.MinLength != nil && len(value) < *v.MinLength {
return validationError(fmt.Sprintf("%s must be at least %d characters", def.Name, *v.MinLength))
}
if v.MaxLength != nil && len(value) > *v.MaxLength {
return validationError(fmt.Sprintf("%s must be at most %d characters", def.Name, *v.MaxLength))
}
// Number validation
if def.Type == AttributeTypeNumber && value != "" {
num, err := strconv.Atoi(value)
if err != nil {
return validationError(fmt.Sprintf("%s must be a number", def.Name))
}
if v.Min != nil && num < *v.Min {
return validationError(fmt.Sprintf("%s must be at least %d", def.Name, *v.Min))
}
if v.Max != nil && num > *v.Max {
return validationError(fmt.Sprintf("%s must be at most %d", def.Name, *v.Max))
}
}
// Pattern validation
if v.Pattern != nil && *v.Pattern != "" && value != "" {
re, err := regexp.Compile(*v.Pattern)
if err == nil && !re.MatchString(value) {
msg := def.Name + " format is invalid"
if v.Message != nil && *v.Message != "" {
msg = *v.Message
}
return validationError(msg)
}
}
// Select validation
if def.Type == AttributeTypeSelect && value != "" {
found := false
for _, opt := range def.Options {
if opt.Value == value {
found = true
break
}
}
if !found {
return validationError(fmt.Sprintf("%s: invalid option", def.Name))
}
}
// Multi-select validation (stored as JSON array)
if def.Type == AttributeTypeMultiSelect && value != "" {
var values []string
if err := json.Unmarshal([]byte(value), &values); err != nil {
// Try comma-separated fallback
values = strings.Split(value, ",")
}
for _, val := range values {
val = strings.TrimSpace(val)
found := false
for _, opt := range def.Options {
if opt.Value == val {
found = true
break
}
}
if !found {
return validationError(fmt.Sprintf("%s: invalid option %s", def.Name, val))
}
}
}
return nil
}
// validationError creates a validation error with a custom message
func validationError(msg string) error {
return infraerrors.BadRequest("ATTRIBUTE_VALIDATION_FAILED", msg)
}
func isValidAttributeType(t UserAttributeType) bool {
switch t {
case AttributeTypeText, AttributeTypeTextarea, AttributeTypeNumber,
AttributeTypeEmail, AttributeTypeURL, AttributeTypeDate,
AttributeTypeSelect, AttributeTypeMultiSelect:
return true
}
return false
}

View File

@@ -14,6 +14,14 @@ var (
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
)
// UserListFilters contains all filter options for listing users
type UserListFilters struct {
Status string // User status filter
Role string // User role filter
Search string // Search in email, username
Attributes map[int64]string // Custom attribute filters: attributeID -> value
}
type UserRepository interface {
Create(ctx context.Context, user *User) error
GetByID(ctx context.Context, id int64) (*User, error)
@@ -23,7 +31,7 @@ type UserRepository interface {
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -36,7 +44,6 @@ type UserRepository interface {
type UpdateProfileRequest struct {
Email *string `json:"email"`
Username *string `json:"username"`
Wechat *string `json:"wechat"`
Concurrency *int `json:"concurrency"`
}
@@ -100,10 +107,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username
}
if req.Wechat != nil {
user.Wechat = *req.Wechat
}
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}

View File

@@ -73,6 +73,15 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
return svc
}
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache)
if cfg != nil {
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
}
return svc
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
@@ -94,6 +103,7 @@ var ProviderSet = wire.NewSet(
NewOAuthService,
NewOpenAIOAuthService,
NewGeminiOAuthService,
NewGeminiQuotaService,
NewAntigravityOAuthService,
NewGeminiTokenProvider,
NewGeminiMessagesCompatService,
@@ -107,7 +117,7 @@ var ProviderSet = wire.NewSet(
ProvideEmailQueueService,
NewTurnstileService,
NewSubscriptionService,
NewConcurrencyService,
ProvideConcurrencyService,
NewIdentityService,
NewCRSSyncService,
ProvideUpdateService,
@@ -115,4 +125,5 @@ var ProviderSet = wire.NewSet(
ProvideTimingWheelService,
ProvideDeferredService,
ProvideAntigravityQuotaRefresher,
NewUserAttributeService,
)

View File

@@ -0,0 +1,48 @@
-- Add user attribute definitions and values tables for custom user attributes.
-- User Attribute Definitions table (with soft delete support)
CREATE TABLE IF NOT EXISTS user_attribute_definitions (
id BIGSERIAL PRIMARY KEY,
key VARCHAR(100) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT DEFAULT '',
type VARCHAR(20) NOT NULL,
options JSONB DEFAULT '[]'::jsonb,
required BOOLEAN NOT NULL DEFAULT FALSE,
validation JSONB DEFAULT '{}'::jsonb,
placeholder VARCHAR(255) DEFAULT '',
display_order INT NOT NULL DEFAULT 0,
enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
-- Partial unique index for key (only for non-deleted records)
-- Allows reusing keys after soft delete
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_attribute_definitions_key_unique
ON user_attribute_definitions(key) WHERE deleted_at IS NULL;
CREATE INDEX IF NOT EXISTS idx_user_attribute_definitions_enabled
ON user_attribute_definitions(enabled);
CREATE INDEX IF NOT EXISTS idx_user_attribute_definitions_display_order
ON user_attribute_definitions(display_order);
CREATE INDEX IF NOT EXISTS idx_user_attribute_definitions_deleted_at
ON user_attribute_definitions(deleted_at);
-- User Attribute Values table (hard delete only, no deleted_at)
CREATE TABLE IF NOT EXISTS user_attribute_values (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
attribute_id BIGINT NOT NULL REFERENCES user_attribute_definitions(id) ON DELETE CASCADE,
value TEXT DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(user_id, attribute_id)
);
CREATE INDEX IF NOT EXISTS idx_user_attribute_values_user_id
ON user_attribute_values(user_id);
CREATE INDEX IF NOT EXISTS idx_user_attribute_values_attribute_id
ON user_attribute_values(attribute_id);

View File

@@ -0,0 +1,83 @@
-- Migration: Move wechat field from users table to user_attribute_values
-- This migration:
-- 1. Creates a "wechat" attribute definition
-- 2. Migrates existing wechat data to user_attribute_values
-- 3. Does NOT drop the wechat column (for rollback safety, can be done in a later migration)
-- +goose Up
-- +goose StatementBegin
-- Step 1: Insert wechat attribute definition if not exists
INSERT INTO user_attribute_definitions (key, name, description, type, options, required, validation, placeholder, display_order, enabled, created_at, updated_at)
SELECT 'wechat', '微信', '用户微信号', 'text', '[]'::jsonb, false, '{}'::jsonb, '请输入微信号', 0, true, NOW(), NOW()
WHERE NOT EXISTS (
SELECT 1 FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL
);
-- Step 2: Migrate existing wechat values to user_attribute_values
-- Only migrate non-empty values
INSERT INTO user_attribute_values (user_id, attribute_id, value, created_at, updated_at)
SELECT
u.id,
(SELECT id FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL LIMIT 1),
u.wechat,
NOW(),
NOW()
FROM users u
WHERE u.wechat IS NOT NULL
AND u.wechat != ''
AND u.deleted_at IS NULL
AND NOT EXISTS (
SELECT 1 FROM user_attribute_values uav
WHERE uav.user_id = u.id
AND uav.attribute_id = (SELECT id FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL LIMIT 1)
);
-- Step 3: Update display_order to ensure wechat appears first
UPDATE user_attribute_definitions
SET display_order = -1
WHERE key = 'wechat' AND deleted_at IS NULL;
-- Reorder all attributes starting from 0
WITH ordered AS (
SELECT id, ROW_NUMBER() OVER (ORDER BY display_order, id) - 1 as new_order
FROM user_attribute_definitions
WHERE deleted_at IS NULL
)
UPDATE user_attribute_definitions
SET display_order = ordered.new_order
FROM ordered
WHERE user_attribute_definitions.id = ordered.id;
-- Step 4: Drop the redundant wechat column from users table
ALTER TABLE users DROP COLUMN IF EXISTS wechat;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Restore wechat column
ALTER TABLE users ADD COLUMN IF NOT EXISTS wechat VARCHAR(100) DEFAULT '';
-- Copy attribute values back to users.wechat column
UPDATE users u
SET wechat = uav.value
FROM user_attribute_values uav
JOIN user_attribute_definitions uad ON uav.attribute_id = uad.id
WHERE uav.user_id = u.id
AND uad.key = 'wechat'
AND uad.deleted_at IS NULL;
-- Delete migrated attribute values
DELETE FROM user_attribute_values
WHERE attribute_id IN (
SELECT id FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL
);
-- Soft-delete the wechat attribute definition
UPDATE user_attribute_definitions
SET deleted_at = NOW()
WHERE key = 'wechat' AND deleted_at IS NULL;
-- +goose StatementEnd

Some files were not shown because too many files have changed in this diff Show More